61 Commits

Author SHA1 Message Date
polydez
2a5b8ffb21 feat: implement functionality needed for computing openings for recent blocks (#367)
* refactor: make `InnerNode` and `NodeMutation` public
* feat: implement serialization for `LeafIndex`
2025-01-24 17:32:30 -08:00
polydez
589839fef1 feat: reverse mutations generation, mutations serialization (#355)
* feat: reverse mutations generation, mutations serialization
* tests: check both `apply_mutations` and `apply_mutations_with_reversion`
* feat: add `num_leaves` method for `Smt`
* refactor: improve ad-hoc benchmarks
* chore: update crate version to v0.13.1
2024-12-26 18:16:38 -08:00
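For context on the reversion feature in #355 above: a minimal sketch of the intended flow, assuming `apply_mutations_with_reversion` returns an inverse mutation set. The method name comes from the commit bullets; the `MutationSet` generics, `MerkleError`, and `SMT_DEPTH` are assumptions about the crate's public `merkle` module, not confirmed signatures.

```rust
use miden_crypto::{
    hash::rpo::RpoDigest,
    merkle::{MerkleError, MutationSet, Smt, SMT_DEPTH},
    Word,
};

/// Applies a mutation set, then uses the returned reversion set to roll the
/// tree back to its previous state (e.g. when a block is un-applied).
fn apply_then_revert(
    smt: &mut Smt,
    mutations: MutationSet<SMT_DEPTH, RpoDigest, Word>,
) -> Result<(), MerkleError> {
    let old_root = smt.root();

    // Applying with reversion yields the inverse mutation set (per #355).
    let reversion = smt.apply_mutations_with_reversion(mutations)?;

    // Rolling back is just applying the inverse set.
    smt.apply_mutations(reversion)?;
    assert_eq!(smt.root(), old_root);
    Ok(())
}
```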
crStiv
1444bbc0f2 fix: typos of different importance (#359) 2024-12-16 10:27:51 -08:00
Bobbin Threadbare
c64f43b262 chore: merge v0.13.0 release 2024-11-24 22:36:08 -08:00
Bobbin Threadbare
1867f842d3 chore: update changelog 2024-11-24 22:26:51 -08:00
Al-Kindi-0
e1072ecc7f chore: update winterfell dependencies to 0.11 (#346) 2024-11-24 22:20:19 -08:00
Bobbin Threadbare
063ad49afd chore: update crate version to v0.13.0 2024-11-21 15:56:55 -08:00
Philipp Gackstatter
a27f9ad828 refactor: use thiserror to derive errors and update error messages (#344) 2024-11-21 15:52:20 -08:00
Al-Kindi-0
50dd6bda19 fix: skip using the field element containing the proof-of-work (#343) 2024-11-18 00:16:27 -08:00
Bobbin Threadbare
3909b01993 chore: merge v0.12.0 release from 0xPolygonMiden/next 2024-10-30 15:25:34 -07:00
Bobbin Threadbare
ee20a49953 chore: increment crate version to v0.12.0 and update changelog 2024-10-30 15:04:08 -07:00
Al-Kindi-0
0d75e3593b chore: migrate to Winterfell v0.10.0 release (#338) 2024-10-29 15:02:46 -07:00
Bobbin Threadbare
d74e746a7f chore: merge v0.11.0 release 2024-10-17 23:26:04 -07:00
Bobbin Threadbare
689cc93ed1 chore: update crate version to v0.11.0 and set MSRV to 1.82 2024-10-17 23:16:41 -07:00
Bobbin Threadbare
7970d3a736 Merge branch 'main' into next 2024-10-17 20:53:09 -07:00
Al-Kindi-0
a734dace1e feat: update RPO's padding rule to use that in the xHash paper (#318) 2024-10-17 20:49:44 -07:00
Andrey Khmuro
940cc04670 feat: add Smt::is_empty (#337) 2024-10-17 14:27:50 -07:00
Andrey Khmuro
e82baa35bb feat: return error instead of panic during MMR verification (#335) 2024-10-17 07:23:29 -07:00
Bobbin Threadbare
876d1bf97a chore: update crate version v0.10.3 2024-09-26 09:37:34 -07:00
Philipp Gackstatter
8adc0ab418 feat: implement get_size_hint for Smt (#331) 2024-09-26 09:13:50 -07:00
Bobbin Threadbare
c2eb38c236 chore: increment crate version to v0.10.2 2024-09-25 03:05:33 -07:00
Philipp Gackstatter
a924ac6b81 feat: Add size hint for digests (#330) 2024-09-25 03:03:31 -07:00
Bobbin Threadbare
e214608c85 fix: bug introduced due to merging 2024-09-13 11:10:34 -07:00
Bobbin Threadbare
c44ccd9dec Merge branch 'main' into next 2024-09-13 11:01:04 -07:00
Bobbin Threadbare
e34900c7d8 chore: update version to v0.10.1 2024-09-13 10:58:06 -07:00
Santiago Pittella
2b184cd4ca feat: add de/serialization to InOrderIndex and PartialMmr (#329) 2024-09-13 08:47:46 -07:00
Bobbin Threadbare
913384600d chore: fix typos 2024-09-11 16:52:21 -07:00
Qyriad
ae807a47ae feat: implement transactional Smt insertion (#327)
* feat(smt): impl constructing leaves that don't yet exist

This commit implements 'prospective leaf construction' -- computing
sparse Merkle tree leaves for a key-value insertion without actually
performing that insertion.

For SimpleSmt, this is trivial, since the leaf type is simply the value
being inserted.

For the full Smt, the new leaf payload depends on the existing payload
in that leaf. Since almost all leaves are very small, we can just clone
the leaf and modify a copy.

This will allow us to perform more general prospective changes on Merkle
trees.

* feat(smt): export get_value() in the trait

* feat(smt): implement generic prospective insertions

This commit adds two methods to SparseMerkleTree: compute_mutations()
and apply_mutations(), which respectively create and consume a new
MutationSet type. This type represents a set of changes to a
SparseMerkleTree that haven't happened yet, and can be queried to
ensure a set of insertions results in the correct tree root before
finalizing and committing the mutation.

This is a direct step towards issue 222, and will directly enable
removing Merkle tree clones in miden-node InnerState::apply_block().

As part of this change, SparseMerkleTree now requires its Key to be Ord
and its Leaf to be Clone (both bounds which were already met by existing
implementations). The Ord bound could instead be changed to Eq + Hash,
if MutationSet were changed to use a HashMap instead of a BTreeMap.

* chore(smt): refactor empty node construction to helper function
2024-09-11 16:49:57 -07:00
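To make the two-phase API described in this commit concrete, a minimal usage sketch: `compute_mutations`, `apply_mutations`, and `MutationSet` are named in the commit message and in the changelog entry for #327, while the concrete key/value types (`RpoDigest`, `Word`) and the `MutationSet::root()` accessor are assumptions rather than confirmed signatures.

```rust
use miden_crypto::{hash::rpo::RpoDigest, merkle::Smt, Word};

/// Stages a batch of key-value insertions, validates the resulting root,
/// and only then commits the changes to the tree.
fn checked_insert(smt: &mut Smt, entries: Vec<(RpoDigest, Word)>, expected_root: RpoDigest) {
    // Phase 1: compute the mutation set without touching the tree.
    let mutations = smt.compute_mutations(entries);

    // The prospective root can be inspected before anything is committed.
    assert_eq!(mutations.root(), expected_root, "unexpected post-insertion root");

    // Phase 2: commit the staged changes in one step.
    smt.apply_mutations(mutations).expect("mutation set was computed against this tree");
}
```

This mirrors the miden-node use case mentioned above: `InnerState::apply_block()` can validate a block's claimed root from `compute_mutations()` output without cloning the tree first.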
Paul-Henry Kajfasz
f4a9d5b027 Merge pull request #323 from 0xPolygonMiden/phklive-consistent-ci
Update `Makefile` and `CI`
2024-08-22 08:22:20 -07:00
Paul-Henry Kajfasz
ee42d87121 Replace i. by 1. 2024-08-22 16:14:19 +01:00
Paul-Henry Kajfasz
b1cb2b6ec3 Fix comments 2024-08-22 15:21:59 +01:00
Paul-Henry Kajfasz
e4a9a2ac00 Updated test in workflow 2024-08-21 16:53:28 +01:00
Paul-Henry Kajfasz
c5077b1683 updated readme 2024-08-21 14:18:41 +01:00
Paul-Henry Kajfasz
2e74028fd4 Updated makefile 2024-08-21 14:11:17 +01:00
Paul-Henry Kajfasz
8bf6ef890d fmt 2024-08-21 14:04:23 +01:00
Paul-Henry Kajfasz
e2aeb25e01 Updated doc comments 2024-08-21 14:03:43 +01:00
Paul-Henry Kajfasz
790846cc73 Merge next 2024-08-21 09:29:39 +01:00
Paul-Henry Kajfasz
4cb6bed428 Updated changelog + added release to no-std 2024-08-19 14:37:58 +01:00
Bobbin Threadbare
a12e62ff22 feat: improve MMR api (#324) 2024-08-18 09:35:12 -07:00
Paul-Henry Kajfasz
9aa4987858 Merge branch 'phklive-consistent-ci' of github.com:0xPolygonMiden/crypto into phklive-consistent-ci 2024-08-16 17:29:29 -07:00
Paul-Henry Kajfasz
70a0a1e970 Removed Makefile.toml 2024-08-16 17:29:09 -07:00
Paul-Henry Kajfasz
025fbb66a9 Update README.md change miden-crypto to crypto 2024-08-17 01:21:19 +01:00
Paul-Henry Kajfasz
5ee5e8554b Ran pre-commit 2024-08-16 16:12:17 -07:00
Paul-Henry Kajfasz
ac3c6976bd Updated Changelog + pre commit 2024-08-16 16:09:51 -07:00
Paul-Henry Kajfasz
374a10f340 Updated ci + added scripts 2024-08-16 15:32:03 -07:00
Paul-Henry Kajfasz
ad0f472708 Updated Makefile and Readme 2024-08-16 15:07:27 -07:00
Bobbin Threadbare
8bb893345b chore: update rust version badge 2024-08-06 17:00:17 -07:00
Bobbin Threadbare
d92fae7f82 chore: update rust version badge 2024-08-06 16:59:31 -07:00
Bobbin Threadbare
b171575776 merge v0.10.0 release 2024-08-06 16:58:00 -07:00
Bobbin Threadbare
dfdd5f722f chore: fix lints 2024-08-06 16:52:46 -07:00
Bobbin Threadbare
9f63b50510 chore: increment crate version to v0.10.0 and update changelog 2024-08-06 16:42:50 -07:00
Elias Rad
d6ab367d32 chore: fix typos (#321) 2024-07-24 11:35:57 -07:00
Al-Kindi-0
b06cfa3c03 docs: update RPO with a comment on security given domain separation (#320) 2024-06-04 22:54:51 -07:00
Al-Kindi-0
8556c8fc43 fix: encoding Falcon secret key basis polynomials (#319) 2024-05-28 23:20:28 -07:00
Augusto Hack
78ac70120d fix: hex_to_bytes can be used for data besides RpoDigests (#317) 2024-05-13 13:13:02 -07:00
Bobbin Threadbare
ccde10af13 chore: update changelog 2024-05-12 03:17:06 +08:00
Al-Kindi-0
f967211b5a feat: migrate to new Winterfell (#315) 2024-05-12 03:09:27 +08:00
Augusto Hack
d58c717956 rpo/rpx: export digest error enum (#313) 2024-05-12 03:09:24 +08:00
Augusto Hack
c0743adac9 Rpo256: Add RpoDigest conversions (#311) 2024-05-12 03:09:21 +08:00
Bobbin Threadbare
f72add58cd chore: increment crate version to v0.9.3 and update changelog 2024-04-24 01:02:47 -07:00
Menko
63f97e5621 feat: add rpx random coin (#307) 2024-04-24 01:02:47 -07:00
75 changed files with 3869 additions and 1577 deletions

.config/nextest.toml (new file, +3)

@@ -0,0 +1,3 @@
[profile.default]
failure-output = "immediate-final"
fail-fast = false

.github/workflows/build.yml (new file, +25)

@@ -0,0 +1,25 @@
# Runs build related jobs.
name: build

on:
  push:
    branches: [main, next]
  pull_request:
    types: [opened, reopened, synchronize]

jobs:
  no-std:
    name: Build for no-std
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        toolchain: [stable, nightly]

    steps:
      - uses: actions/checkout@main
      - name: Build for no-std
        run: |
          rustup update --no-self-update ${{ matrix.toolchain }}
          rustup target add wasm32-unknown-unknown
          make build-no-std

.github/workflows/changelog.yml (new file, +23)

@@ -0,0 +1,23 @@
# Runs changelog related jobs.
# CI job heavily inspired by: https://github.com/tarides/changelog-check-action
name: changelog

on:
  pull_request:
    types: [opened, reopened, synchronize, labeled, unlabeled]

jobs:
  changelog:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@main
        with:
          fetch-depth: 0

      - name: Check for changes in changelog
        env:
          BASE_REF: ${{ github.event.pull_request.base.ref }}
          NO_CHANGELOG_LABEL: ${{ contains(github.event.pull_request.labels.*.name, 'no changelog') }}
        run: ./scripts/check-changelog.sh "${{ inputs.changelog }}"
        shell: bash

.github/workflows/doc.yml (deleted)

@@ -1,31 +0,0 @@
# Runs documentation related jobs.
name: doc

on:
  push:
    branches:
      - main
  pull_request:
    types: [opened, reopened, synchronize]

jobs:
  docs:
    name: Verify the docs on ${{matrix.toolchain}}
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        toolchain: [stable]
    steps:
      - uses: actions/checkout@v4
        with:
          submodules: recursive
      - name: Install rust
        uses: actions-rs/toolchain@v1
        with:
          toolchain: ${{matrix.toolchain}}
          override: true
      - uses: davidB/rust-cargo-make@v1
      - name: cargo make - doc
        run: cargo make doc

.github/workflows/lint.yml

@@ -4,63 +4,50 @@ name: lint
 on:
   push:
-    branches:
-      - main
+    branches: [main, next]
   pull_request:
     types: [opened, reopened, synchronize]

 jobs:
+  clippy:
+    name: clippy nightly on ubuntu-latest
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@main
+      - name: Clippy
+        run: |
+          rustup update --no-self-update nightly
+          rustup +nightly component add clippy
+          make clippy
+
+  rustfmt:
+    name: rustfmt check nightly on ubuntu-latest
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@main
+      - name: Rustfmt
+        run: |
+          rustup update --no-self-update nightly
+          rustup +nightly component add rustfmt
+          make format-check
+
+  doc:
+    name: doc stable on ubuntu-latest
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@main
+      - name: Build docs
+        run: |
+          rustup update --no-self-update
+          make doc
+
   version:
     name: check rust version consistency
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@main
        with:
          profile: minimal
          override: true
       - name: check rust versions
         run: ./scripts/check-rust-version.sh
-  rustfmt:
-    name: rustfmt ${{matrix.toolchain}} on ${{matrix.os}}
-    runs-on: ${{matrix.os}}-latest
-    strategy:
-      fail-fast: false
-      matrix:
-        toolchain: [nightly]
-        os: [ubuntu]
-    steps:
-      - uses: actions/checkout@v4
-      - name: Install minimal Rust with rustfmt
-        uses: actions-rs/toolchain@v1
-        with:
-          profile: minimal
-          toolchain: ${{matrix.toolchain}}
-          components: rustfmt
-          override: true
-      - uses: davidB/rust-cargo-make@v1
-      - name: cargo make - format-check
-        run: cargo make format-check
-  clippy:
-    name: clippy ${{matrix.toolchain}} on ${{matrix.os}}
-    runs-on: ${{matrix.os}}-latest
-    strategy:
-      fail-fast: false
-      matrix:
-        toolchain: [stable]
-        os: [ubuntu]
-    steps:
-      - uses: actions/checkout@v4
-        with:
-          submodules: recursive
-      - name: Install minimal Rust with clippy
-        uses: actions-rs/toolchain@v1
-        with:
-          profile: minimal
-          toolchain: ${{matrix.toolchain}}
-          components: clippy
-          override: true
-      - uses: davidB/rust-cargo-make@v1
-      - name: cargo make - clippy
-        run: cargo make clippy

.github/workflows/no-std.yml (deleted)

@@ -1,32 +0,0 @@
# Runs no-std related jobs.
name: no-std

on:
  push:
    branches:
      - main
  pull_request:
    types: [opened, reopened, synchronize]

jobs:
  no-std:
    name: build ${{matrix.toolchain}} no-std for wasm32-unknown-unknown
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        toolchain: [stable, nightly]
    steps:
      - uses: actions/checkout@v4
        with:
          submodules: recursive
      - name: Install rust
        uses: actions-rs/toolchain@v1
        with:
          toolchain: ${{matrix.toolchain}}
          override: true
      - run: rustup target add wasm32-unknown-unknown
      - uses: davidB/rust-cargo-make@v1
      - name: cargo make - build-no-std
        run: cargo make build-no-std

.github/workflows/test.yml

@@ -1,34 +1,28 @@
-# Runs testing related jobs
+# Runs test related jobs.
 name: test

 on:
   push:
-    branches:
-      - main
+    branches: [main, next]
   pull_request:
     types: [opened, reopened, synchronize]

 jobs:
   test:
-    name: test ${{matrix.toolchain}} on ${{matrix.os}} with ${{matrix.features}}
+    name: test ${{matrix.toolchain}} on ${{matrix.os}} with ${{matrix.args}}
     runs-on: ${{matrix.os}}-latest
     strategy:
       fail-fast: false
       matrix:
         toolchain: [stable, nightly]
         os: [ubuntu]
-        features: ["test", "test-no-default-features"]
+        args: [default, no-std]
     timeout-minutes: 30
     steps:
-      - uses: actions/checkout@v4
-        with:
-          submodules: recursive
-      - name: Install rust
-        uses: actions-rs/toolchain@v1
-        with:
-          toolchain: ${{matrix.toolchain}}
-          override: true
-      - uses: davidB/rust-cargo-make@v1
-      - name: cargo make - test
-        run: cargo make ${{matrix.features}}
+      - uses: actions/checkout@main
+      - uses: taiki-e/install-action@nextest
+      - name: Perform tests
+        run: |
+          rustup update --no-self-update ${{matrix.toolchain}}
+          make test-${{matrix.args}}

.pre-commit-config.yaml

@@ -1,43 +1,34 @@
 # See https://pre-commit.com for more information
 # See https://pre-commit.com/hooks.html for more hooks
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v3.2.0
+    rev: v4.6.0
     hooks:
       - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-json
      - id: check-toml
      - id: pretty-format-json
      - id: check-added-large-files
      - id: check-case-conflict
      - id: check-executables-have-shebangs
      - id: check-merge-conflict
      - id: detect-private-key
-  - repo: https://github.com/hackaugusto/pre-commit-cargo
-    rev: v1.0.0
-    hooks:
-      # Allows cargo fmt to modify the source code prior to the commit
-      - id: cargo
-        name: Cargo fmt
-        args: ["+stable", "fmt", "--all"]
-        stages: [commit]
-      # Requires code to be properly formatted prior to pushing upstream
-      - id: cargo
-        name: Cargo fmt --check
-        args: ["+stable", "fmt", "--all", "--check"]
-        stages: [push, manual]
-      - id: cargo
-        name: Cargo check --all-targets
-        args: ["+stable", "check", "--all-targets"]
-      - id: cargo
-        name: Cargo check --all-targets --no-default-features
-        args: ["+stable", "check", "--all-targets", "--no-default-features"]
-      - id: cargo
-        name: Cargo check --all-targets --features default,std,serde
-        args: ["+stable", "check", "--all-targets", "--features", "default,std,serde"]
-      # Unlike fmt, clippy will not be automatically applied
-      - id: cargo
-        name: Cargo clippy
-        args: ["+nightly", "clippy", "--workspace", "--", "--deny", "clippy::all", "--deny", "warnings"]
+  - repo: local
+    hooks:
+      - id: lint
+        name: Make lint
+        stages: [commit]
+        language: rust
+        entry: make lint
+      - id: doc
+        name: Make doc
+        stages: [commit]
+        language: rust
+        entry: make doc
+      - id: check
+        name: Make check
+        stages: [commit]
+        language: rust
+        entry: make check

CHANGELOG.md

@@ -1,74 +1,125 @@
+## 0.13.2 (2025-01-24)
+
+- Made `InnerNode` and `NodeMutation` public. Implemented (de)serialization of `LeafIndex` (#367).
+
+## 0.13.1 (2024-12-26)
+
+- Generate reverse mutation set when applying a mutation set; implemented serialization of `MutationSet` (#355).
+
+## 0.13.0 (2024-11-24)
+
+- Fixed a bug in the implementation of `draw_integers` for `RpoRandomCoin` (#343).
+- [BREAKING] Refactor error messages and use `thiserror` to derive errors (#344).
+- [BREAKING] Updated Winterfell dependency to v0.11 (#346).
+
+## 0.12.0 (2024-10-30)
+
+- [BREAKING] Updated Winterfell dependency to v0.10 (#338).
+
+## 0.11.0 (2024-10-17)
+
+- [BREAKING] Renamed `Mmr::open()` into `Mmr::open_at()` and `Mmr::peaks()` into `Mmr::peaks_at()` (#234).
+- Added `Mmr::open()` and `Mmr::peaks()` which rely on `Mmr::open_at()` and `Mmr::peaks_at()` respectively (#234).
+- Standardized CI and Makefile across Miden repos (#323).
+- Added `Smt::compute_mutations()` and `Smt::apply_mutations()` for validation-checked insertions (#327).
+- Changed padding rule for RPO/RPX hash functions (#318).
+- [BREAKING] Changed return value of the `Mmr::verify()` and `MerklePath::verify()` from `bool` to `Result<>` (#335).
+- Added `is_empty()` functions to the `SimpleSmt` and `Smt` structures. Added `EMPTY_ROOT` constant to the `SparseMerkleTree` trait (#337).
+
+## 0.10.3 (2024-09-25)
+
+- Implement `get_size_hint` for `Smt` (#331).
+
+## 0.10.2 (2024-09-25)
+
+- Implement `get_size_hint` for `RpoDigest` and `RpxDigest` and expose constants for their serialized size (#330).
+
+## 0.10.1 (2024-09-13)
+
+- Added `Serializable` and `Deserializable` implementations for `PartialMmr` and `InOrderIndex` (#329).
+
+## 0.10.0 (2024-08-06)
+
+- Added more `RpoDigest` and `RpxDigest` conversions (#311).
+- [BREAKING] Migrated to Winterfell v0.9 (#315).
+- Fixed encoding of Falcon secret key (#319).
+
+## 0.9.3 (2024-04-24)
+
+- Added `RpxRandomCoin` struct (#307).
+
 ## 0.9.2 (2024-04-21)
-* Implemented serialization for the `Smt` struct (#304).
-* Fixed a bug in Falcon signature generation (#305).
+- Implemented serialization for the `Smt` struct (#304).
+- Fixed a bug in Falcon signature generation (#305).

 ## 0.9.1 (2024-04-02)
-* Added `num_leaves()` method to `SimpleSmt` (#302).
+- Added `num_leaves()` method to `SimpleSmt` (#302).

 ## 0.9.0 (2024-03-24)
-* [BREAKING] Removed deprecated re-exports from liballoc/libstd (#290).
-* [BREAKING] Refactored RpoFalcon512 signature to work with pure Rust (#285).
-* [BREAKING] Added `RngCore` as supertrait for `FeltRng` (#299).
+- [BREAKING] Removed deprecated re-exports from liballoc/libstd (#290).
+- [BREAKING] Refactored RpoFalcon512 signature to work with pure Rust (#285).
+- [BREAKING] Added `RngCore` as supertrait for `FeltRng` (#299).

 # 0.8.4 (2024-03-17)
-* Re-added unintentionally removed re-exported liballoc macros (`vec` and `format` macros).
+- Re-added unintentionally removed re-exported liballoc macros (`vec` and `format` macros).

 # 0.8.3 (2024-03-17)
-* Re-added unintentionally removed re-exported liballoc macros (#292).
+- Re-added unintentionally removed re-exported liballoc macros (#292).

 # 0.8.2 (2024-03-17)
-* Updated `no-std` approach to be in sync with winterfell v0.8.3 release (#290).
+- Updated `no-std` approach to be in sync with winterfell v0.8.3 release (#290).

 ## 0.8.1 (2024-02-21)
-* Fixed clippy warnings (#280)
+- Fixed clippy warnings (#280)

 ## 0.8.0 (2024-02-14)
-* Implemented the `PartialMmr` data structure (#195).
-* Implemented RPX hash function (#201).
-* Added `FeltRng` and `RpoRandomCoin` (#237).
-* Accelerated RPO/RPX hash functions using AVX512 instructions (#234).
-* Added `inner_nodes()` method to `PartialMmr` (#238).
-* Improved `PartialMmr::apply_delta()` (#242).
-* Refactored `SimpleSmt` struct (#245).
-* Replaced `TieredSmt` struct with `Smt` struct (#254, #277).
-* Updated Winterfell dependency to v0.8 (#275).
+- Implemented the `PartialMmr` data structure (#195).
+- Implemented RPX hash function (#201).
+- Added `FeltRng` and `RpoRandomCoin` (#237).
+- Accelerated RPO/RPX hash functions using AVX512 instructions (#234).
+- Added `inner_nodes()` method to `PartialMmr` (#238).
+- Improved `PartialMmr::apply_delta()` (#242).
+- Refactored `SimpleSmt` struct (#245).
+- Replaced `TieredSmt` struct with `Smt` struct (#254, #277).
+- Updated Winterfell dependency to v0.8 (#275).

 ## 0.7.1 (2023-10-10)
-* Fixed RPO Falcon signature build on Windows.
+- Fixed RPO Falcon signature build on Windows.

 ## 0.7.0 (2023-10-05)
-* Replaced `MerklePathSet` with `PartialMerkleTree` (#165).
-* Implemented clearing of nodes in `TieredSmt` (#173).
-* Added ability to generate inclusion proofs for `TieredSmt` (#174).
-* Implemented Falcon DSA (#179).
-* Added conditional `serde` support for various structs (#180).
-* Implemented benchmarking for `TieredSmt` (#182).
-* Added more leaf traversal methods for `MerkleStore` (#185).
-* Added SVE acceleration for RPO hash function (#189).
+- Replaced `MerklePathSet` with `PartialMerkleTree` (#165).
+- Implemented clearing of nodes in `TieredSmt` (#173).
+- Added ability to generate inclusion proofs for `TieredSmt` (#174).
+- Implemented Falcon DSA (#179).
+- Added conditional `serde` support for various structs (#180).
+- Implemented benchmarking for `TieredSmt` (#182).
+- Added more leaf traversal methods for `MerkleStore` (#185).
+- Added SVE acceleration for RPO hash function (#189).

 ## 0.6.0 (2023-06-25)
-* [BREAKING] Added support for recording capabilities for `MerkleStore` (#162).
-* [BREAKING] Refactored Merkle struct APIs to use `RpoDigest` instead of `Word` (#157).
-* Added initial implementation of `PartialMerkleTree` (#156).
+- [BREAKING] Added support for recording capabilities for `MerkleStore` (#162).
+- [BREAKING] Refactored Merkle struct APIs to use `RpoDigest` instead of `Word` (#157).
+- Added initial implementation of `PartialMerkleTree` (#156).

 ## 0.5.0 (2023-05-26)
-* Implemented `TieredSmt` (#152, #153).
-* Implemented ability to extract a subset of a `MerkleStore` (#151).
-* Cleaned up `SimpleSmt` interface (#149).
-* Decoupled hashing and padding of peaks in `Mmr` (#148).
-* Added `inner_nodes()` to `MerkleStore` (#146).
+- Implemented `TieredSmt` (#152, #153).
+- Implemented ability to extract a subset of a `MerkleStore` (#151).
+- Cleaned up `SimpleSmt` interface (#149).
+- Decoupled hashing and padding of peaks in `Mmr` (#148).
+- Added `inner_nodes()` to `MerkleStore` (#146).

 ## 0.4.0 (2023-04-21)

@@ -116,6 +167,6 @@
 - Initial release on crates.io containing the cryptographic primitives used in Miden VM and the Miden Rollup.
 - Hash module with the BLAKE3 and Rescue Prime Optimized hash functions.
   - BLAKE3 is implemented with 256-bit, 192-bit, or 160-bit output.
   - RPO is implemented with 256-bit output.
 - Merkle module, with a set of data structures related to Merkle trees, implemented using the RPO hash function.

Cargo.lock (generated, +469): file diff suppressed because it is too large

Cargo.toml

@@ -1,16 +1,16 @@
 [package]
 name = "miden-crypto"
-version = "0.9.2"
+version = "0.13.2"
 description = "Miden Cryptographic primitives"
 authors = ["miden contributors"]
 readme = "README.md"
 license = "MIT"
 repository = "https://github.com/0xPolygonMiden/crypto"
-documentation = "https://docs.rs/miden-crypto/0.9.2"
+documentation = "https://docs.rs/miden-crypto/0.13.1"
 categories = ["cryptography", "no-std"]
 keywords = ["miden", "crypto", "hash", "merkle"]
 edition = "2021"
-rust-version = "1.75"
+rust-version = "1.82"

 [[bin]]
 name = "miden-crypto"

@@ -52,22 +52,24 @@ num = { version = "0.4", default-features = false, features = ["alloc", "libm"] }
 num-complex = { version = "0.4", default-features = false }
 rand = { version = "0.8", default-features = false }
 rand_core = { version = "0.6", default-features = false }
-rand-utils = { version = "0.8", package = "winter-rand-utils", optional = true }
+rand-utils = { version = "0.11", package = "winter-rand-utils", optional = true }
 serde = { version = "1.0", default-features = false, optional = true, features = ["derive"] }
 sha3 = { version = "0.10", default-features = false }
-winter-crypto = { version = "0.8", default-features = false }
-winter-math = { version = "0.8", default-features = false }
-winter-utils = { version = "0.8", default-features = false }
+thiserror = { version = "2.0", default-features = false }
+winter-crypto = { version = "0.11", default-features = false }
+winter-math = { version = "0.11", default-features = false }
+winter-utils = { version = "0.11", default-features = false }

 [dev-dependencies]
+assert_matches = { version = "1.5", default-features = false }
 criterion = { version = "0.5", features = ["html_reports"] }
 getrandom = { version = "0.2", features = ["js"] }
 hex = { version = "0.4", default-features = false, features = ["alloc"] }
-proptest = "1.4"
+proptest = "1.6"
 rand_chacha = { version = "0.3", default-features = false }
-rand-utils = { version = "0.8", package = "winter-rand-utils" }
+rand-utils = { version = "0.11", package = "winter-rand-utils" }
 seq-macro = { version = "0.3" }

 [build-dependencies]
-cc = { version = "1.0", optional = true, features = ["parallel"] }
+cc = { version = "1.2", optional = true, features = ["parallel"] }
 glob = "0.3"

Makefile (new file, +90)

@@ -0,0 +1,90 @@
.DEFAULT_GOAL := help

.PHONY: help
help:
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'

# -- variables --------------------------------------------------------------------------------------

WARNINGS=RUSTDOCFLAGS="-D warnings"
DEBUG_OVERFLOW_INFO=RUSTFLAGS="-C debug-assertions -C overflow-checks -C debuginfo=2"

# -- linting --------------------------------------------------------------------------------------

.PHONY: clippy
clippy: ## Run Clippy with configs
	$(WARNINGS) cargo +nightly clippy --workspace --all-targets --all-features

.PHONY: fix
fix: ## Run Fix with configs
	cargo +nightly fix --allow-staged --allow-dirty --all-targets --all-features

.PHONY: format
format: ## Run Format using nightly toolchain
	cargo +nightly fmt --all

.PHONY: format-check
format-check: ## Run Format using nightly toolchain but only in check mode
	cargo +nightly fmt --all --check

.PHONY: lint
lint: format fix clippy ## Run all linting tasks at once (Clippy, fixing, formatting)

# --- docs ----------------------------------------------------------------------------------------

.PHONY: doc
doc: ## Generate and check documentation
	$(WARNINGS) cargo doc --all-features --keep-going --release

# --- testing -------------------------------------------------------------------------------------

.PHONY: test-default
test-default: ## Run tests with default features
	$(DEBUG_OVERFLOW_INFO) cargo nextest run --profile default --release --all-features

.PHONY: test-no-std
test-no-std: ## Run tests with `no-default-features` (std)
	$(DEBUG_OVERFLOW_INFO) cargo nextest run --profile default --release --no-default-features

.PHONY: test
test: test-default test-no-std ## Run all tests

# --- checking ------------------------------------------------------------------------------------

.PHONY: check
check: ## Check all targets and features for errors without code generation
	cargo check --all-targets --all-features

# --- building ------------------------------------------------------------------------------------

.PHONY: build
build: ## Build with default features enabled
	cargo build --release

.PHONY: build-no-std
build-no-std: ## Build without the standard library
	cargo build --release --no-default-features --target wasm32-unknown-unknown

.PHONY: build-avx2
build-avx2: ## Build with avx2 support
	RUSTFLAGS="-C target-feature=+avx2" cargo build --release

.PHONY: build-sve
build-sve: ## Build with sve support
	RUSTFLAGS="-C target-feature=+sve" cargo build --release

# --- benchmarking --------------------------------------------------------------------------------

.PHONY: bench
bench: ## Run crypto benchmarks
	cargo bench

.PHONY: bench-smt-concurrent
bench-smt-concurrent: ## Run SMT benchmarks with concurrent feature
	cargo run --release --features executable -- --size 1000000

Makefile.toml (deleted)

@@ -1,86 +0,0 @@
# Cargo Makefile
# -- linting --------------------------------------------------------------------------------------
[tasks.format]
toolchain = "nightly"
command = "cargo"
args = ["fmt", "--all"]
[tasks.format-check]
toolchain = "nightly"
command = "cargo"
args = ["fmt", "--all", "--", "--check"]
[tasks.clippy-default]
command = "cargo"
args = ["clippy","--workspace", "--all-targets", "--", "-D", "clippy::all", "-D", "warnings"]
[tasks.clippy-all-features]
command = "cargo"
args = ["clippy","--workspace", "--all-targets", "--all-features", "--", "-D", "clippy::all", "-D", "warnings"]
[tasks.clippy]
dependencies = [
"clippy-default",
"clippy-all-features"
]
[tasks.fix]
description = "Runs Fix"
command = "cargo"
toolchain = "nightly"
args = ["fix", "--allow-staged", "--allow-dirty", "--all-targets", "--all-features"]
[tasks.lint]
description = "Runs all linting tasks (Clippy, fixing, formatting)"
run_task = { name = ["format", "format-check", "clippy", "docs"] }
# --- docs ----------------------------------------------------------------------------------------
[tasks.doc]
env = { "RUSTDOCFLAGS" = "-D warnings" }
command = "cargo"
args = ["doc", "--all-features", "--keep-going", "--release"]
# --- testing -------------------------------------------------------------------------------------
[tasks.test]
description = "Run tests with default features"
env = { "RUSTFLAGS" = "-C debug-assertions -C overflow-checks -C debuginfo=2" }
workspace = false
command = "cargo"
args = ["test", "--release"]
[tasks.test-no-default-features]
description = "Run tests with no-default-features"
env = { "RUSTFLAGS" = "-C debug-assertions -C overflow-checks -C debuginfo=2" }
workspace = false
command = "cargo"
args = ["test", "--release", "--no-default-features"]
[tasks.test-all]
description = "Run all tests"
workspace = false
run_task = { name = ["test", "test-no-default-features"], parallel = true }
# --- building ------------------------------------------------------------------------------------
[tasks.build]
description = "Build in release mode"
command = "cargo"
args = ["build", "--release"]
[tasks.build-no-std]
description = "Build using no-std"
command = "cargo"
args = ["build", "--release", "--no-default-features", "--target", "wasm32-unknown-unknown"]
[tasks.build-avx2]
description = "Build using AVX2 acceleration"
env = { "RUSTFLAGS" = "-C target-feature=+avx2" }
command = "cargo"
args = ["build", "--release"]
[tasks.build-sve]
description = "Build with SVE acceleration"
env = { "RUSTFLAGS" = "-C target-feature=+sve" }
command = "cargo"
args = ["build", "--release"]

README.md

@@ -2,84 +2,107 @@
 [![LICENSE](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/0xPolygonMiden/crypto/blob/main/LICENSE)
 [![test](https://github.com/0xPolygonMiden/crypto/actions/workflows/test.yml/badge.svg)](https://github.com/0xPolygonMiden/crypto/actions/workflows/test.yml)
-[![no-std](https://github.com/0xPolygonMiden/crypto/actions/workflows/no-std.yml/badge.svg)](https://github.com/0xPolygonMiden/crypto/actions/workflows/no-std.yml)
-[![RUST_VERSION](https://img.shields.io/badge/rustc-1.75+-lightgray.svg)]()
+[![build](https://github.com/0xPolygonMiden/crypto/actions/workflows/build.yml/badge.svg)](https://github.com/0xPolygonMiden/crypto/actions/workflows/build.yml)
+[![RUST_VERSION](https://img.shields.io/badge/rustc-1.82+-lightgray.svg)](https://www.rust-lang.org/tools/install)
 [![CRATE](https://img.shields.io/crates/v/miden-crypto)](https://crates.io/crates/miden-crypto)

 This crate contains cryptographic primitives used in Polygon Miden.

 ## Hash

 [Hash module](./src/hash) provides a set of cryptographic hash functions which are used by the Miden VM and the Miden rollup. Currently, these functions are:

-* [BLAKE3](https://github.com/BLAKE3-team/BLAKE3) hash function with 256-bit, 192-bit, or 160-bit output. The 192-bit and 160-bit outputs are obtained by truncating the 256-bit output of the standard BLAKE3.
-* [RPO](https://eprint.iacr.org/2022/1577) hash function with 256-bit output. This hash function is an algebraic hash function suitable for recursive STARKs.
-* [RPX](https://eprint.iacr.org/2023/1045) hash function with 256-bit output. Similar to RPO, this hash function is suitable for recursive STARKs but it is about 2x faster as compared to RPO.
+- [BLAKE3](https://github.com/BLAKE3-team/BLAKE3) hash function with 256-bit, 192-bit, or 160-bit output. The 192-bit and 160-bit outputs are obtained by truncating the 256-bit output of the standard BLAKE3.
+- [RPO](https://eprint.iacr.org/2022/1577) hash function with 256-bit output. This hash function is an algebraic hash function suitable for recursive STARKs.
+- [RPX](https://eprint.iacr.org/2023/1045) hash function with 256-bit output. Similar to RPO, this hash function is suitable for recursive STARKs but it is about 2x faster as compared to RPO.

 For performance benchmarks of these hash functions and their comparison to other popular hash functions please see [here](./benches/).

 ## Merkle

 [Merkle module](./src/merkle/) provides a set of data structures related to Merkle trees. All these data structures are implemented using the RPO hash function described above. The data structures are:

-* `MerkleStore`: a collection of Merkle trees of different heights designed to efficiently store trees with common subtrees. When instantiated with `RecordingMap`, a Merkle store records all accesses to the original data.
-* `MerkleTree`: a regular fully-balanced binary Merkle tree. The depth of this tree can be at most 64.
-* `Mmr`: a Merkle mountain range structure designed to function as an append-only log.
-* `PartialMerkleTree`: a partial view of a Merkle tree where some sub-trees may not be known. This is similar to a collection of Merkle paths all resolving to the same root. The length of the paths can be at most 64.
-* `PartialMmr`: a partial view of a Merkle mountain range structure.
-* `SimpleSmt`: a Sparse Merkle Tree (with no compaction), mapping 64-bit keys to 4-element values.
-* `Smt`: a Sparse Merkle tree (with compaction at depth 64), mapping 4-element keys to 4-element values.
+- `MerkleStore`: a collection of Merkle trees of different heights designed to efficiently store trees with common subtrees. When instantiated with `RecordingMap`, a Merkle store records all accesses to the original data.
+- `MerkleTree`: a regular fully-balanced binary Merkle tree. The depth of this tree can be at most 64.
+- `Mmr`: a Merkle mountain range structure designed to function as an append-only log.
+- `PartialMerkleTree`: a partial view of a Merkle tree where some sub-trees may not be known. This is similar to a collection of Merkle paths all resolving to the same root. The length of the paths can be at most 64.
+- `PartialMmr`: a partial view of a Merkle mountain range structure.
+- `SimpleSmt`: a Sparse Merkle Tree (with no compaction), mapping 64-bit keys to 4-element values.
+- `Smt`: a Sparse Merkle tree (with compaction at depth 64), mapping 4-element keys to 4-element values.

 The module also contains additional supporting components such as `NodeIndex`, `MerklePath`, and `MerkleError` to assist with tree indexation, opening proofs, and reporting inconsistent arguments/state.

 ## Signatures

 [DSA module](./src/dsa) provides a set of digital signature schemes supported by default in the Miden VM. Currently, these schemes are:

-* `RPO Falcon512`: a variant of the [Falcon](https://falcon-sign.info/) signature scheme. This variant differs from the standard in that instead of using SHAKE256 hash function in the *hash-to-point* algorithm we use RPO256. This makes the signature more efficient to verify in Miden VM.
+- `RPO Falcon512`: a variant of the [Falcon](https://falcon-sign.info/) signature scheme. This variant differs from the standard in that instead of using SHAKE256 hash function in the _hash-to-point_ algorithm we use RPO256. This makes the signature more efficient to verify in Miden VM.

 For the above signatures, key generation, signing, and signature verification are available for both `std` and `no_std` contexts (see [crate features](#crate-features) below). However, in `no_std` context, the user is responsible for supplying the key generation and signing procedures with a random number generator.

 ## Pseudo-Random Element Generator

 [Pseudo random element generator module](./src/rand/) provides a set of traits and data structures that facilitate generating pseudo-random elements in the context of Miden VM and Miden rollup. The module currently includes:

-* `FeltRng`: a trait for generating random field elements and random 4 field elements.
-* `RpoRandomCoin`: a struct implementing `FeltRng` as well as the [`RandomCoin`](https://github.com/facebook/winterfell/blob/main/crypto/src/random/mod.rs) trait.
+- `FeltRng`: a trait for generating random field elements and random 4 field elements.
+- `RpoRandomCoin`: a struct implementing `FeltRng` as well as the [`RandomCoin`](https://github.com/facebook/winterfell/blob/main/crypto/src/random/mod.rs) trait using RPO hash function.
+- `RpxRandomCoin`: a struct implementing `FeltRng` as well as the [`RandomCoin`](https://github.com/facebook/winterfell/blob/main/crypto/src/random/mod.rs) trait using RPX hash function.

+## Make commands
+
+We use `make` to automate building, testing, and other processes. In most cases, `make` commands are wrappers around `cargo` commands with specific arguments. You can view the list of available commands in the [Makefile](Makefile), or run the following command:
+
+```shell
+make
+```

 ## Crate features

 This crate can be compiled with the following features:

-* `std` - enabled by default and relies on the Rust standard library.
-* `no_std` does not rely on the Rust standard library and enables compilation to WebAssembly.
+- `std` - enabled by default and relies on the Rust standard library.
+- `no_std` does not rely on the Rust standard library and enables compilation to WebAssembly.

 Both of these features imply the use of [alloc](https://doc.rust-lang.org/alloc/) to support heap-allocated collections.

-To compile with `no_std`, disable default features via `--no-default-features` flag.
+To compile with `no_std`, disable default features via `--no-default-features` flag or using the following command:
+
+```shell
+make build-no-std
+```

 ### AVX2 acceleration

 On platforms with [AVX2](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions) support, RPO and RPX hash function can be accelerated by using the vector processing unit. To enable AVX2 acceleration, the code needs to be compiled with the `avx2` target feature enabled. For example:

 ```shell
-cargo make build-avx2
+make build-avx2
 ```

 ### SVE acceleration

-On platforms with [SVE](https://en.wikipedia.org/wiki/AArch64#Scalable_Vector_Extension_(SVE)) support, RPO and RPX hash function can be accelerated by using the vector processing unit. To enable SVE acceleration, the code needs to be compiled with the `sve` target feature enabled. For example:
+On platforms with [SVE](<https://en.wikipedia.org/wiki/AArch64#Scalable_Vector_Extension_(SVE)>) support, RPO and RPX hash function can be accelerated by using the vector processing unit. To enable SVE acceleration, the code needs to be compiled with the `sve` target feature enabled. For example:

 ```shell
-cargo make build-sve
+make build-sve
 ```

 ## Testing

-The best way to test the library is using our `Makefile.toml` and [cargo-make](https://github.com/sagiegurari/cargo-make), this will enable you to use our pre-defined optimized testing commands:
+The best way to test the library is using our [Makefile](Makefile), this will enable you to use our pre-defined optimized testing commands:

 ```shell
-cargo make test-all
+make test
 ```

 For example, some of the functions are heavy and might take a while for the tests to complete if using simply `cargo test`. In order to test in release and optimized mode, we have to replicate the test conditions of the development mode so all debug assertions can be verified.

-We do that by enabling some special [flags](https://doc.rust-lang.org/cargo/reference/profiles.html) for the compilation (which we have set as a default in our [Makefile.toml](Makefile.toml)):
+We do that by enabling some special [flags](https://doc.rust-lang.org/cargo/reference/profiles.html) for the compilation (which we have set as a default in our [Makefile](Makefile)):

 ```shell
 RUSTFLAGS="-C debug-assertions -C overflow-checks -C debuginfo=2" cargo test --release
 ```

 ## License

 This project is [MIT licensed](./LICENSE).

rust-toolchain (deleted)

@@ -1 +0,0 @@
1.75

rust-toolchain.toml (new file, +5)

@@ -0,0 +1,5 @@
[toolchain]
channel = "1.82"
components = ["rustfmt", "rust-src", "clippy"]
targets = ["wasm32-unknown-unknown"]
profile = "minimal"

rustfmt.toml

@@ -2,20 +2,22 @@ edition = "2021"
 array_width = 80
 attr_fn_like_width = 80
 chain_width = 80
-#condense_wildcard_suffixes = true
-#enum_discrim_align_threshold = 40
+comment_width = 100
+condense_wildcard_suffixes = true
 fn_call_width = 80
-#fn_single_line = true
-#format_code_in_doc_comments = true
-#format_macro_matchers = true
-#format_strings = true
-#group_imports = "StdExternalCrate"
-#hex_literal_case = "Lower"
-#imports_granularity = "Crate"
+format_code_in_doc_comments = true
+format_macro_matchers = true
+group_imports = "StdExternalCrate"
+hex_literal_case = "Lower"
+imports_granularity = "Crate"
+match_block_trailing_comma = true
 newline_style = "Unix"
-#normalize_doc_attributes = true
-#reorder_impl_items = true
+reorder_imports = true
+reorder_modules = true
 single_line_if_else_max_width = 60
+single_line_let_else_max_width = 60
 struct_lit_width = 40
+struct_variant_width = 40
 use_field_init_shorthand = true
 use_try_shorthand = true
+wrap_comments = true

scripts/check-changelog.sh (new executable file, +21)

@@ -0,0 +1,21 @@
#!/bin/bash
set -uo pipefail

CHANGELOG_FILE="${1:-CHANGELOG.md}"

if [ "${NO_CHANGELOG_LABEL}" = "true" ]; then
    # 'no changelog' set, so finish successfully
    echo "\"no changelog\" label has been set"
    exit 0
else
    # a changelog check is required
    # fail if the diff is empty
    if git diff --exit-code "origin/${BASE_REF}" -- "${CHANGELOG_FILE}"; then
        >&2 echo "Changes should come with an entry in the \"CHANGELOG.md\" file. This behavior
can be overridden by using the \"no changelog\" label, which is used for changes
that are trivial / explicitly stated not to require a changelog entry."
        exit 1
    fi

    echo "The \"CHANGELOG.md\" file has been updated."
fi

scripts/check-rust-version.sh

@@ -1,10 +1,12 @@
 #!/bin/bash

-# Check rust-toolchain file
-TOOLCHAIN_VERSION=$(cat rust-toolchain)
+# Get rust-toolchain.toml file channel
+TOOLCHAIN_VERSION=$(grep 'channel' rust-toolchain.toml | sed -E 's/.*"(.*)".*/\1/')

-# Check workspace Cargo.toml file
-CARGO_VERSION=$(cat Cargo.toml | grep "rust-version" | cut -d '"' -f 2)
+# Get workspace Cargo.toml file rust-version
+CARGO_VERSION=$(grep 'rust-version' Cargo.toml | sed -E 's/.*"(.*)".*/\1/')

+# Check version match
 if [ "$CARGO_VERSION" != "$TOOLCHAIN_VERSION" ]; then
     echo "Mismatch in Cargo.toml: Expected $TOOLCHAIN_VERSION, found $CARGO_VERSION"
     exit 1

src/dsa/rpo_falcon512/hash_to_point.rs

@@ -1,7 +1,9 @@
-use super::{math::FalconFelt, Nonce, Polynomial, Rpo256, Word, MODULUS, N, ZERO};
 use alloc::vec::Vec;
+
 use num::Zero;

+use super::{math::FalconFelt, Nonce, Polynomial, Rpo256, Word, MODULUS, N, ZERO};
+
 // HASH-TO-POINT FUNCTIONS
 // ================================================================================================

src/dsa/rpo_falcon512/keys/mod.rs

@@ -15,12 +15,13 @@ pub use secret_key::SecretKey;
 #[cfg(test)]
 mod tests {
-    use crate::{dsa::rpo_falcon512::SecretKey, Word, ONE};
     use rand::SeedableRng;
     use rand_chacha::ChaCha20Rng;
     use winter_math::FieldElement;
     use winter_utils::{Deserializable, Serializable};

+    use crate::{dsa::rpo_falcon512::SecretKey, Word, ONE};
+
     #[test]
     fn test_falcon_verification() {
         let seed = [0_u8; 32];

src/dsa/rpo_falcon512/keys/public_key.rs

@@ -1,13 +1,14 @@
-use crate::dsa::rpo_falcon512::FALCON_ENCODING_BITS;
+use alloc::string::ToString;
+use core::ops::Deref;
+
+use num::Zero;

 use super::{
     super::{Rpo256, LOG_N, N, PK_LEN},
     ByteReader, ByteWriter, Deserializable, DeserializationError, FalconFelt, Felt, Polynomial,
     Serializable, Signature, Word,
 };
-use alloc::string::ToString;
-use core::ops::Deref;
-use num::Zero;
+use crate::dsa::rpo_falcon512::FALCON_ENCODING_BITS;

 // PUBLIC KEY
 // ================================================================================================

@@ -116,7 +117,7 @@ impl Deserializable for PubKeyPoly {
         if acc_len >= FALCON_ENCODING_BITS {
             acc_len -= FALCON_ENCODING_BITS;

-            let w = (acc >> acc_len) & 0x3FFF;
+            let w = (acc >> acc_len) & 0x3fff;

             let element = w.try_into().map_err(|err| {
                 DeserializationError::InvalidValue(format!(
                     "Failed to decode public key: {err}"

src/dsa/rpo_falcon512/keys/secret_key.rs

@@ -1,3 +1,11 @@
+use alloc::{string::ToString, vec::Vec};
+
+use num::Complex;
+#[cfg(not(feature = "std"))]
+use num::Float;
+use num_complex::Complex64;
+use rand::Rng;
+
 use super::{
     super::{
         math::{ffldl, ffsampling, gram, normalize_tree, FalconFelt, FastFft, LdlTree, Polynomial},

@@ -10,13 +18,6 @@ use super::{
 use crate::dsa::rpo_falcon512::{
     hash_to_point::hash_to_point_rpo256, math::ntru_gen, SIG_NONCE_LEN, SK_LEN,
 };
-use alloc::{string::ToString, vec::Vec};
-use num::Complex;
-use num_complex::Complex64;
-use rand::Rng;
-#[cfg(not(feature = "std"))]
-use num::Float;

 // CONSTANTS
 // ================================================================================================

@@ -27,6 +28,8 @@ const WIDTH_SMALL_POLY_COEFFICIENT: usize = 6;
 // SECRET KEY
 // ================================================================================================

+/// Represents the secret key for Falcon DSA.
+///
 /// The secret key is a quadruple [[g, -f], [G, -F]] of polynomials with integer coefficients. Each
 /// polynomial is of degree at most N = 512 and computations with these polynomials is done modulo
 /// the monic irreducible polynomial ϕ = x^N + 1. The secret key is a basis for a lattice and has

@@ -217,15 +220,27 @@ impl Serializable for SecretKey {
         let mut buffer = Vec::with_capacity(1281);
         buffer.push(header);

-        let f_i8: Vec<i8> = neg_f.coefficients.iter().map(|&a| -a as i8).collect();
+        let f_i8: Vec<i8> = neg_f
+            .coefficients
+            .iter()
+            .map(|&a| FalconFelt::new(-a).balanced_value() as i8)
+            .collect();
         let f_i8_encoded = encode_i8(&f_i8, WIDTH_SMALL_POLY_COEFFICIENT).unwrap();
         buffer.extend_from_slice(&f_i8_encoded);

-        let g_i8: Vec<i8> = g.coefficients.iter().map(|&a| a as i8).collect();
+        let g_i8: Vec<i8> = g
+            .coefficients
+            .iter()
+            .map(|&a| FalconFelt::new(a).balanced_value() as i8)
+            .collect();
         let g_i8_encoded = encode_i8(&g_i8, WIDTH_SMALL_POLY_COEFFICIENT).unwrap();
         buffer.extend_from_slice(&g_i8_encoded);

-        let big_f_i8: Vec<i8> = neg_big_f.coefficients.iter().map(|&a| -a as i8).collect();
+        let big_f_i8: Vec<i8> = neg_big_f
+            .coefficients
+            .iter()
+            .map(|&a| FalconFelt::new(-a).balanced_value() as i8)
+            .collect();
         let big_f_i8_encoded = encode_i8(&big_f_i8, WIDTH_BIG_POLY_COEFFICIENT).unwrap();
         buffer.extend_from_slice(&big_f_i8_encoded);

         target.write_bytes(&buffer);

src/dsa/rpo_falcon512/math/mod.rs

@@ -1,11 +1,12 @@
-use super::{fft::FastFft, polynomial::Polynomial, samplerz::sampler_z};
 use alloc::boxed::Box;
+
+#[cfg(not(feature = "std"))]
+use num::Float;
 use num::{One, Zero};
 use num_complex::{Complex, Complex64};
 use rand::Rng;
-#[cfg(not(feature = "std"))]
-use num::Float;
+
+use super::{fft::FastFft, polynomial::Polynomial, samplerz::sampler_z};

 const SIGMIN: f64 = 1.2778336969128337;

@@ -80,11 +81,11 @@ pub fn normalize_tree(tree: &mut LdlTree, sigma: f64) {
         LdlTree::Branch(_ell, left, right) => {
             normalize_tree(left, sigma);
             normalize_tree(right, sigma);
-        }
+        },
         LdlTree::Leaf(vector) => {
             vector[0] = Complex::new(sigma / vector[0].re.sqrt(), 0.0);
             vector[1] = Complex64::zero();
-        }
+        },
     }
 }

@@ -110,7 +111,7 @@ pub fn ffsampling<R: Rng>(
             let z0 = Polynomial::<Complex64>::merge_fft(&bold_z0.0, &bold_z0.1);

             (z0, z1)
-        }
+        },
         LdlTree::Leaf(value) => {
             let z0 = sampler_z(t.0.coefficients[0].re, value[0].re, SIGMIN, &mut rng);
             let z1 = sampler_z(t.1.coefficients[0].re, value[0].re, SIGMIN, &mut rng);

@@ -118,6 +119,6 @@
                 Polynomial::new(vec![Complex64::new(z0 as f64, 0.0)]),
                 Polynomial::new(vec![Complex64::new(z1 as f64, 0.0)]),
             )
-        }
+        },
     }
 }

src/dsa/rpo_falcon512/math/fft.rs

@@ -1,14 +1,15 @@
use super::{field::FalconFelt, polynomial::Polynomial, Inverse};
use alloc::vec::Vec; use alloc::vec::Vec;
use core::{ use core::{
f64::consts::PI, f64::consts::PI,
ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
}; };
use num::{One, Zero};
use num_complex::Complex64;
#[cfg(not(feature = "std"))] #[cfg(not(feature = "std"))]
use num::Float; use num::Float;
use num::{One, Zero};
use num_complex::Complex64;
use super::{field::FalconFelt, polynomial::Polynomial, Inverse};
/// Implements Cyclotomic FFT without bitreversing the outputs, and using precomputed powers of the /// Implements Cyclotomic FFT without bitreversing the outputs, and using precomputed powers of the
/// 2n-th primitive root of unity. /// 2n-th primitive root of unity.
@@ -73,7 +74,7 @@ where
rev rev
} }
/// Computes the first n powers of the 2nth root of unity, and put them in bit-reversed order. /// Computes the first n powers of the 2nd root of unity, and put them in bit-reversed order.
#[allow(dead_code)] #[allow(dead_code)]
fn bitreversed_powers(n: usize) -> Vec<Self> { fn bitreversed_powers(n: usize) -> Vec<Self> {
let psi = Self::primitive_root_of_unity(2 * n); let psi = Self::primitive_root_of_unity(2 * n);
@@ -87,7 +88,7 @@ where
array array
} }
/// Computes the first n powers of the 2nth root of unity, invert them, and put them in /// Computes the first n powers of the 2n-th root of unity, inverts them, and puts them in
/// bit-reversed order. /// bit-reversed order.
#[allow(dead_code)] #[allow(dead_code)]
fn bitreversed_powers_inverse(n: usize) -> Vec<Self> { fn bitreversed_powers_inverse(n: usize) -> Vec<Self> {
@@ -102,7 +103,8 @@ where
array array
} }
/// Reorders the given elements in the array by reversing the binary expansions of their indices. /// Reorders the given elements in the array by reversing the binary expansions of their
/// indices.
fn bitreverse_array<T>(array: &mut [T]) { fn bitreverse_array<T>(array: &mut [T]) {
let n = array.len(); let n = array.len();
for i in 0..n { for i in 0..n {
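A self-contained sketch of the reordering described above, assuming the array length is a power of two (illustrative, not necessarily the trait's exact implementation):

/// Swaps a[i] with a[j], where j is the bit-reversal of i within log2(n) bits.
fn bitreverse_array<T>(a: &mut [T]) {
    let n = a.len();
    if n <= 1 {
        return;
    }
    let bits = n.trailing_zeros(); // n is assumed to be a power of two
    for i in 0..n {
        let j = (i.reverse_bits() >> (usize::BITS - bits)) as usize;
        if i < j {
            a.swap(i, j);
        }
    }
}

fn main() {
    let mut a = [0, 1, 2, 3, 4, 5, 6, 7];
    bitreverse_array(&mut a);
    assert_eq!(a, [0, 4, 2, 6, 1, 5, 3, 7]);
}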
@@ -118,19 +120,14 @@ where
/// ///
/// Arguments: /// Arguments:
/// ///
/// - a : &mut [Self] /// - a : &mut [Self] (a reference to) a mutable array of field elements which is to be
/// (a reference to) a mutable array of field elements which is to /// transformed under the FFT. The transformation happens in-place.
/// be transformed under the FFT. The transformation happens in-
/// place.
/// ///
/// - psi_rev: &[Self] /// - psi_rev: &[Self] (a reference to) an array of powers of psi, from 0 to n-1, but ordered
/// (a reference to) an array of powers of psi, from 0 to n-1, /// by bit-reversed index. Here psi is a primitive root of order 2n. You can use
/// but ordered by bit-reversed index. Here psi is a primitive root /// `Self::bitreversed_powers(psi, n)` for this purpose, but this trait implementation is not
/// of order 2n. You can use /// const. For the performance benefit you want a precompiled array, which you can get
/// `Self::bitreversed_powers(psi, n)` for this purpose, but this /// by implementing the same method and marking it "const".
/// trait implementation is not const. For the performance benefit
/// you want a precompiled array, which you can get
/// by implementing the same method and marking it "const".
fn fft(a: &mut [Self], psi_rev: &[Self]) { fn fft(a: &mut [Self], psi_rev: &[Self]) {
let n = a.len(); let n = a.len();
let mut t = n; let mut t = n;
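For intuition, a runnable sketch of one standard realization of this in-place transform (the Longa-Naehrig butterfly schedule, with t halving and m doubling each round); the toy modulus and names are illustrative, not the crate's field:

/// In-place negacyclic NTT over Z_q, with psi_rev holding the powers of a
/// primitive 2n-th root of unity in bit-reversed order.
fn fft(a: &mut [u64], psi_rev: &[u64], q: u64) {
    let n = a.len();
    let mut t = n;
    let mut m = 1;
    while m < n {
        t /= 2;
        for i in 0..m {
            let s = psi_rev[m + i];
            for j in (2 * i * t)..(2 * i * t + t) {
                let u = a[j];
                let v = a[j + t] * s % q;
                a[j] = (u + v) % q;
                a[j + t] = (u + q - v) % q;
            }
        }
        m *= 2;
    }
}

fn main() {
    // q = 17, n = 2: psi = 4 satisfies psi^2 = -1 (mod 17), so the transform
    // evaluates a0 + a1*x at x = 4 and x = 4^3 = 13.
    let mut a = [3, 5];
    fft(&mut a, &[1, 4], 17);
    assert_eq!(a, [(3 + 4 * 5) % 17, (3 + 17 - 4 * 5 % 17) % 17]);
}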
@@ -158,20 +155,15 @@ where
/// ///
/// Arguments: /// Arguments:
/// ///
/// - a : &mut [Self] /// - a : &mut [Self] (a reference to) a mutable array of field elements which is to be
/// (a reference to) a mutable array of field elements which is to /// transformed under the IFFT. The transformation happens in-place.
/// be transformed under the IFFT. The transformation happens in-
/// place.
/// ///
/// - psi_inv_rev: &[Self] /// - psi_inv_rev: &[Self] (a reference to) an array of powers of psi^-1, from 0 to n-1, but
/// (a reference to) an array of powers of psi^-1, from 0 to n-1, /// ordered by bit-reversed index. Here psi is a primitive root of order 2n. You can use
/// but ordered by bit-reversed index. Here psi is a primitive root of /// `Self::bitreversed_powers(Self::inverse_or_zero(psi), n)` for this purpose, but this
/// order 2n. You can use /// trait implementation is not const. For the performance benefit you want a precompiled
/// `Self::bitreversed_powers(Self::inverse_or_zero(psi), n)` for /// array, which you can get by implementing the same methods and marking them
/// this purpose, but this trait implementation is not const. For /// "const".
/// the performance benefit you want a precompiled array, which you
/// can get by implementing the same methods and marking
/// them "const".
fn ifft(a: &mut [Self], psi_inv_rev: &[Self], ninv: Self) { fn ifft(a: &mut [Self], psi_inv_rev: &[Self], ninv: Self) {
let n = a.len(); let n = a.len();
let mut t = 1; let mut t = 1;

View File

@@ -1,8 +1,10 @@
use super::{fft::CyclotomicFourier, Inverse, MODULUS};
use alloc::string::String; use alloc::string::String;
use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use num::{One, Zero}; use num::{One, Zero};
use super::{fft::CyclotomicFourier, Inverse, MODULUS};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct FalconFelt(u32); pub struct FalconFelt(u32);

View File

@@ -2,17 +2,19 @@
//! //!
//! It uses and acknowledges the work in: //! It uses and acknowledges the work in:
//! //!
//! 1. The [reference](https://falcon-sign.info/impl/README.txt.html) implementation by Thomas Pornin. //! 1. The [reference](https://falcon-sign.info/impl/README.txt.html) implementation by Thomas
//! Pornin.
//! 2. The [Rust](https://github.com/aszepieniec/falcon-rust) implementation by Alan Szepieniec. //! 2. The [Rust](https://github.com/aszepieniec/falcon-rust) implementation by Alan Szepieniec.
use super::MODULUS;
use alloc::{string::String, vec::Vec}; use alloc::{string::String, vec::Vec};
use core::ops::MulAssign; use core::ops::MulAssign;
#[cfg(not(feature = "std"))]
use num::Float;
use num::{BigInt, FromPrimitive, One, Zero}; use num::{BigInt, FromPrimitive, One, Zero};
use num_complex::Complex64; use num_complex::Complex64;
use rand::Rng; use rand::Rng;
#[cfg(not(feature = "std"))] use super::MODULUS;
use num::Float;
mod fft; mod fft;
pub use fft::{CyclotomicFourier, FastFft}; pub use fft::{CyclotomicFourier, FastFft};
@@ -152,7 +154,7 @@ fn ntru_solve(
{ {
None None
} }
} },
} }
} }

View File

@@ -1,12 +1,18 @@
use super::{field::FalconFelt, Inverse};
use crate::dsa::rpo_falcon512::{MODULUS, N};
use crate::Felt;
use alloc::vec::Vec; use alloc::vec::Vec;
use core::default::Default; use core::{
use core::fmt::Debug; default::Default,
use core::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign}; fmt::Debug,
ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign},
};
use num::{One, Zero}; use num::{One, Zero};
use super::{field::FalconFelt, Inverse};
use crate::{
dsa::rpo_falcon512::{MODULUS, N},
Felt,
};
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
pub struct Polynomial<F> { pub struct Polynomial<F> {
pub coefficients: Vec<F>, pub coefficients: Vec<F>,
@@ -134,8 +140,8 @@ impl<
Self::new(coefficients) Self::new(coefficients)
} }
/// Computes the galois adjoint of the polynomial in the cyclotomic ring F\[ X \] / < X^n + 1 > , /// Computes the galois adjoint of the polynomial in the cyclotomic ring F\[ X \] / < X^n + 1 >
/// which corresponds to f(x^2). /// , which corresponds to f(x^2).
pub fn galois_adjoint(&self) -> Self { pub fn galois_adjoint(&self) -> Self {
Self::new( Self::new(
self.coefficients self.coefficients

View File

@@ -1,8 +1,8 @@
use core::f64::consts::LN_2; use core::f64::consts::LN_2;
use rand::Rng;
#[cfg(not(feature = "std"))] #[cfg(not(feature = "std"))]
use num::Float; use num::Float;
use rand::Rng;
/// Samples an integer from {0, ..., 18} according to the distribution χ, which is close to /// Samples an integer from {0, ..., 18} according to the distribution χ, which is close to
/// the half-Gaussian distribution on the natural numbers with mean 0 and standard deviation /// the half-Gaussian distribution on the natural numbers with mean 0 and standard deviation
@@ -40,18 +40,18 @@ fn approx_exp(x: f64, ccs: f64) -> u64 {
// https://eprint.iacr.org/2018/1234 // https://eprint.iacr.org/2018/1234
// https://github.com/raykzhao/gaussian // https://github.com/raykzhao/gaussian
const C: [u64; 13] = [ const C: [u64; 13] = [
0x00000004741183A3u64, 0x00000004741183a3u64,
0x00000036548CFC06u64, 0x00000036548cfc06u64,
0x0000024FDCBF140Au64, 0x0000024fdcbf140au64,
0x0000171D939DE045u64, 0x0000171d939de045u64,
0x0000D00CF58F6F84u64, 0x0000d00cf58f6f84u64,
0x000680681CF796E3u64, 0x000680681cf796e3u64,
0x002D82D8305B0FEAu64, 0x002d82d8305b0feau64,
0x011111110E066FD0u64, 0x011111110e066fd0u64,
0x0555555555070F00u64, 0x0555555555070f00u64,
0x155555555581FF00u64, 0x155555555581ff00u64,
0x400000000002B400u64, 0x400000000002b400u64,
0x7FFFFFFFFFFF4800u64, 0x7fffffffffff4800u64,
0x8000000000000000u64, 0x8000000000000000u64,
]; ];
@@ -116,9 +116,10 @@ pub(crate) fn sampler_z<R: Rng>(mu: f64, sigma: f64, sigma_min: f64, rng: &mut R
#[cfg(all(test, feature = "std"))] #[cfg(all(test, feature = "std"))]
mod test { mod test {
use alloc::vec::Vec; use alloc::vec::Vec;
use rand::RngCore;
use std::{thread::sleep, time::Duration}; use std::{thread::sleep, time::Duration};
use rand::RngCore;
use super::{approx_exp, ber_exp, sampler_z}; use super::{approx_exp, ber_exp, sampler_z};
/// RNG used only for testing purposes, whereby the produced /// RNG used only for testing purposes, whereby the produced

View File

@@ -9,9 +9,11 @@ mod keys;
mod math; mod math;
mod signature; mod signature;
pub use self::keys::{PubKeyPoly, PublicKey, SecretKey}; pub use self::{
pub use self::math::Polynomial; keys::{PubKeyPoly, PublicKey, SecretKey},
pub use self::signature::{Signature, SignatureHeader, SignaturePoly}; math::Polynomial,
signature::{Signature, SignatureHeader, SignaturePoly},
};
// CONSTANTS // CONSTANTS
// ================================================================================================ // ================================================================================================

View File

@@ -1,6 +1,8 @@
use alloc::{string::ToString, vec::Vec}; use alloc::{string::ToString, vec::Vec};
use core::ops::Deref; use core::ops::Deref;
use num::Zero;
use super::{ use super::{
hash_to_point::hash_to_point_rpo256, hash_to_point::hash_to_point_rpo256,
keys::PubKeyPoly, keys::PubKeyPoly,
@@ -8,7 +10,6 @@ use super::{
ByteReader, ByteWriter, Deserializable, DeserializationError, Felt, Nonce, Rpo256, ByteReader, ByteWriter, Deserializable, DeserializationError, Felt, Nonce, Rpo256,
Serializable, Word, LOG_N, MODULUS, N, SIG_L2_BOUND, SIG_POLY_BYTE_LEN, Serializable, Word, LOG_N, MODULUS, N, SIG_L2_BOUND, SIG_POLY_BYTE_LEN,
}; };
use num::Zero;
// FALCON SIGNATURE // FALCON SIGNATURE
// ================================================================================================ // ================================================================================================
@@ -38,12 +39,12 @@ use num::Zero;
/// The signature is serialized as: /// The signature is serialized as:
/// 1. A header byte specifying the algorithm used to encode the coefficients of the `s2` polynomial /// 1. A header byte specifying the algorithm used to encode the coefficients of the `s2` polynomial
/// together with the degree of the irreducible polynomial phi. For RPO Falcon512, the header /// together with the degree of the irreducible polynomial phi. For RPO Falcon512, the header
/// byte is set to `10111001` which differentiates it from the standardized instantiation of /// byte is set to `10111001` which differentiates it from the standardized instantiation of the
/// the Falcon signature. /// Falcon signature.
/// 2. 40 bytes for the nonce. /// 2. 40 bytes for the nonce.
/// 3. 625 bytes encoding the `s2` polynomial above. /// 3. 625 bytes encoding the `s2` polynomial above.
/// ///
/// The total size of the signature is (including the extended public key) is 1563 bytes. /// The total size of the signature (including the extended public key) is 1563 bytes.
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct Signature { pub struct Signature {
header: SignatureHeader, header: SignatureHeader,
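The byte counts in the serialization list above add up as follows. The 1563-byte total is from the doc comment; the 897-byte extended public key length is an assumption based on the standard Falcon-512 public key size:

const HEADER_BYTES: usize = 1; // algorithm/degree header byte
const NONCE_BYTES: usize = 40;
const S2_BYTES: usize = 625; // the encoded `s2` polynomial (SIG_POLY_BYTE_LEN)
const EXT_PK_BYTES: usize = 897; // assumed: standard Falcon-512 public key size

fn main() {
    assert_eq!(HEADER_BYTES + NONCE_BYTES + S2_BYTES + EXT_PK_BYTES, 1563);
}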
@@ -355,10 +356,11 @@ fn are_coefficients_valid(x: &[i16]) -> bool {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{super::SecretKey, *};
use rand::SeedableRng; use rand::SeedableRng;
use rand_chacha::ChaCha20Rng; use rand_chacha::ChaCha20Rng;
use super::{super::SecretKey, *};
#[test] #[test]
fn test_serialization_round_trip() { fn test_serialization_round_trip() {
let seed = [0_u8; 32]; let seed = [0_u8; 32];

View File

@@ -1,8 +1,8 @@
use alloc::string::String; use alloc::{string::String, vec::Vec};
use core::{ use core::{
mem::{size_of, transmute, transmute_copy}, mem::{size_of, transmute, transmute_copy},
ops::Deref, ops::Deref,
slice::from_raw_parts, slice::{self, from_raw_parts},
}; };
use super::{Digest, ElementHasher, Felt, FieldElement, Hasher}; use super::{Digest, ElementHasher, Felt, FieldElement, Hasher};
@@ -33,6 +33,14 @@ const DIGEST20_BYTES: usize = 20;
#[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))] #[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))]
pub struct Blake3Digest<const N: usize>([u8; N]); pub struct Blake3Digest<const N: usize>([u8; N]);
impl<const N: usize> Blake3Digest<N> {
pub fn digests_as_bytes(digests: &[Blake3Digest<N>]) -> &[u8] {
let p = digests.as_ptr();
let len = digests.len() * N;
unsafe { slice::from_raw_parts(p as *const u8, len) }
}
}
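Context for the `unsafe` block above: `Blake3Digest<N>` wraps a plain `[u8; N]`, so a slice of digests is already a contiguous run of `digests.len() * N` bytes, and the pointer cast only reinterprets it without copying. A safe but allocating sketch of the equivalent logic, with a stand-in type (not the crate's):

struct Digest<const N: usize>([u8; N]);

/// Produces the same bytes as the zero-copy cast, at the cost of one allocation.
fn digests_as_bytes<const N: usize>(digests: &[Digest<N>]) -> Vec<u8> {
    digests.iter().flat_map(|d| d.0).collect()
}

fn main() {
    let ds = [Digest([1u8, 2]), Digest([3, 4])];
    assert_eq!(digests_as_bytes(&ds), vec![1, 2, 3, 4]);
}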
impl<const N: usize> Default for Blake3Digest<N> { impl<const N: usize> Default for Blake3Digest<N> {
fn default() -> Self { fn default() -> Self {
Self([0; N]) Self([0; N])
@@ -114,6 +122,10 @@ impl Hasher for Blake3_256 {
Self::hash(prepare_merge(values)) Self::hash(prepare_merge(values))
} }
fn merge_many(values: &[Self::Digest]) -> Self::Digest {
Blake3Digest(blake3::hash(Blake3Digest::digests_as_bytes(values)).into())
}
fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest { fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
let mut hasher = blake3::Hasher::new(); let mut hasher = blake3::Hasher::new();
hasher.update(&seed.0); hasher.update(&seed.0);
@@ -174,6 +186,11 @@ impl Hasher for Blake3_192 {
Blake3Digest(*shrink_bytes(&blake3::hash(bytes).into())) Blake3Digest(*shrink_bytes(&blake3::hash(bytes).into()))
} }
fn merge_many(values: &[Self::Digest]) -> Self::Digest {
let bytes: Vec<u8> = values.iter().flat_map(|v| v.as_bytes()).collect();
Blake3Digest(*shrink_bytes(&blake3::hash(&bytes).into()))
}
fn merge(values: &[Self::Digest; 2]) -> Self::Digest { fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
Self::hash(prepare_merge(values)) Self::hash(prepare_merge(values))
} }
@@ -242,6 +259,11 @@ impl Hasher for Blake3_160 {
Self::hash(prepare_merge(values)) Self::hash(prepare_merge(values))
} }
fn merge_many(values: &[Self::Digest]) -> Self::Digest {
let bytes: Vec<u8> = values.iter().flat_map(|v| v.as_bytes()).collect();
Blake3Digest(*shrink_bytes(&blake3::hash(&bytes).into()))
}
fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest { fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
let mut hasher = blake3::Hasher::new(); let mut hasher = blake3::Hasher::new();
hasher.update(&seed.0); hasher.update(&seed.0);

View File

@@ -1,8 +1,9 @@
use alloc::vec::Vec;
use proptest::prelude::*; use proptest::prelude::*;
use rand_utils::rand_vector; use rand_utils::rand_vector;
use super::*; use super::*;
use alloc::vec::Vec;
#[test] #[test]
fn blake3_hash_elements() { fn blake3_hash_elements() {

View File

@@ -1,16 +1,16 @@
//! Cryptographic hash functions used by the Miden VM and the Miden rollup. //! Cryptographic hash functions used by the Miden VM and the Miden rollup.
use super::{CubeExtension, Felt, FieldElement, StarkField, ONE, ZERO}; use super::{CubeExtension, Felt, FieldElement, StarkField, ZERO};
pub mod blake; pub mod blake;
mod rescue; mod rescue;
pub mod rpo { pub mod rpo {
pub use super::rescue::{Rpo256, RpoDigest}; pub use super::rescue::{Rpo256, RpoDigest, RpoDigestError};
} }
pub mod rpx { pub mod rpx {
pub use super::rescue::{Rpx256, RpxDigest}; pub use super::rescue::{Rpx256, RpxDigest, RpxDigestError};
} }
// RE-EXPORTS // RE-EXPORTS

View File

@@ -4,40 +4,43 @@ use core::arch::x86_64::*;
// https://github.com/0xPolygonZero/plonky2/blob/main/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs // https://github.com/0xPolygonZero/plonky2/blob/main/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs
// Preliminary notes: // Preliminary notes:
// 1. AVX does not support addition with carry but 128-bit (2-word) addition can be easily // 1. AVX does not support addition with carry but 128-bit (2-word) addition can be easily emulated.
// emulated. The method recognizes that for a + b overflowed iff (a + b) < a: // The method recognizes that for a + b overflowed iff (a + b) < a:
// i. res_lo = a_lo + b_lo // 1. res_lo = a_lo + b_lo
// ii. carry_mask = res_lo < a_lo // 2. carry_mask = res_lo < a_lo
// iii. res_hi = a_hi + b_hi - carry_mask // 3. res_hi = a_hi + b_hi - carry_mask
//
// Notice that carry_mask is subtracted, not added. This is because AVX comparison instructions // Notice that carry_mask is subtracted, not added. This is because AVX comparison instructions
// return -1 (all bits 1) for true and 0 for false. // return -1 (all bits 1) for true and 0 for false.
// //
// 2. AVX does not have unsigned 64-bit comparisons. Those can be emulated with signed comparisons // 2. AVX does not have unsigned 64-bit comparisons. Those can be emulated with signed comparisons
// by recognizing that a <u b iff a + (1 << 63) <s b + (1 << 63), where the addition wraps around // by recognizing that a <u b iff a + (1 << 63) <s b + (1 << 63), where the addition wraps around
// and the comparisons are unsigned and signed respectively. The shift function adds/subtracts // and the comparisons are unsigned and signed respectively. The shift function adds/subtracts 1
// 1 << 63 to enable this trick. // << 63 to enable this trick. Addition with carry example:
// Example: addition with carry. // 1. a_lo_s = shift(a_lo)
// i. a_lo_s = shift(a_lo) // 2. res_lo_s = a_lo_s + b_lo
// ii. res_lo_s = a_lo_s + b_lo // 3. carry_mask = res_lo_s <s a_lo_s
// iii. carry_mask = res_lo_s <s a_lo_s // 4. res_lo = shift(res_lo_s)
// iv. res_lo = shift(res_lo_s) // 5. res_hi = a_hi + b_hi - carry_mask
// v. res_hi = a_hi + b_hi - carry_mask //
// The suffix _s denotes a value that has been shifted by 1 << 63. The result of addition is // The suffix _s denotes a value that has been shifted by 1 << 63. The result of addition
// shifted if exactly one of the operands is shifted, as is the case on line ii. Line iii. // is shifted if exactly one of the operands is shifted, as is the case on
// performs a signed comparison res_lo_s <s a_lo_s on shifted values to emulate unsigned // line 2. Line 3. performs a signed comparison res_lo_s <s a_lo_s on shifted values to
// comparison res_lo <u a_lo on unshifted values. Finally, line iv. reverses the shift so the // emulate unsigned comparison res_lo <u a_lo on unshifted values. Finally, line 4. reverses the
// result can be returned. // shift so the result can be returned.
// When performing a chain of calculations, we can often save instructions by letting the shift //
// propagate through and only undoing it when necessary. For example, to compute the addition of // When performing a chain of calculations, we can often save instructions by letting
// three two-word (128-bit) numbers we can do: // the shift propagate through and only undoing it when necessary.
// i. a_lo_s = shift(a_lo) // For example, to compute the addition of three two-word (128-bit) numbers we can do:
// ii. tmp_lo_s = a_lo_s + b_lo // 1. a_lo_s = shift(a_lo)
// iii. tmp_carry_mask = tmp_lo_s <s a_lo_s // 2. tmp_lo_s = a_lo_s + b_lo
// iv. tmp_hi = a_hi + b_hi - tmp_carry_mask // 3. tmp_carry_mask = tmp_lo_s <s a_lo_s
// v. res_lo_s = tmp_lo_s + c_lo // 4. tmp_hi = a_hi + b_hi - tmp_carry_mask
// vi. res_carry_mask = res_lo_s <s tmp_lo_s // 5. res_lo_s = tmp_lo_s + c_lo
// vii. res_lo = shift(res_lo_s) // 6. res_carry_mask = res_lo_s <s tmp_lo_s
// viii. res_hi = tmp_hi + c_hi - res_carry_mask // 7. res_lo = shift(res_lo_s)
// 8. res_hi = tmp_hi + c_hi - res_carry_mask
//
// Notice that the above 3-value addition still only requires two calls to shift, just like our // Notice that the above 3-value addition still only requires two calls to shift, just like our
// 2-value addition. // 2-value addition.
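A scalar sketch of both tricks on plain u64 values (the real code applies them lane-wise to AVX vectors; `shift` is the biasing function the notes describe):

/// Biases a value by 1 << 63 (flipping the top bit is addition mod 2^64).
fn shift(x: u64) -> u64 {
    x ^ (1 << 63)
}

/// Unsigned comparison a <u b emulated with a signed comparison on biased values.
fn lt_unsigned_via_signed(a: u64, b: u64) -> bool {
    (shift(a) as i64) < (shift(b) as i64)
}

/// 128-bit addition from 64-bit limbs, using the carry-mask idea above
/// (here the mask is 0/1 instead of AVX's 0/-1, so it is added, not subtracted).
fn add128(a_hi: u64, a_lo: u64, b_hi: u64, b_lo: u64) -> (u64, u64) {
    let res_lo = a_lo.wrapping_add(b_lo);
    let carry = (res_lo < a_lo) as u64; // overflow occurred iff (a + b) < a
    (a_hi.wrapping_add(b_hi).wrapping_add(carry), res_lo)
}

fn main() {
    assert!(lt_unsigned_via_signed(1, u64::MAX));
    assert_eq!(add128(0, u64::MAX, 0, 1), (1, 0));
}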
@@ -60,10 +63,10 @@ pub fn branch_hint() {
} }
macro_rules! map3 { macro_rules! map3 {
($f:ident::<$l:literal>, $v:ident) => { ($f:ident:: < $l:literal > , $v:ident) => {
($f::<$l>($v.0), $f::<$l>($v.1), $f::<$l>($v.2)) ($f::<$l>($v.0), $f::<$l>($v.1), $f::<$l>($v.2))
}; };
($f:ident::<$l:literal>, $v1:ident, $v2:ident) => { ($f:ident:: < $l:literal > , $v1:ident, $v2:ident) => {
($f::<$l>($v1.0, $v2.0), $f::<$l>($v1.1, $v2.1), $f::<$l>($v1.2, $v2.2)) ($f::<$l>($v1.0, $v2.0), $f::<$l>($v1.1, $v2.1), $f::<$l>($v1.2, $v2.2))
}; };
($f:ident, $v:ident) => { ($f:ident, $v:ident) => {
@@ -72,11 +75,11 @@ macro_rules! map3 {
($f:ident, $v0:ident, $v1:ident) => { ($f:ident, $v0:ident, $v1:ident) => {
($f($v0.0, $v1.0), $f($v0.1, $v1.1), $f($v0.2, $v1.2)) ($f($v0.0, $v1.0), $f($v0.1, $v1.1), $f($v0.2, $v1.2))
}; };
($f:ident, rep $v0:ident, $v1:ident) => { ($f:ident,rep $v0:ident, $v1:ident) => {
($f($v0, $v1.0), $f($v0, $v1.1), $f($v0, $v1.2)) ($f($v0, $v1.0), $f($v0, $v1.1), $f($v0, $v1.2))
}; };
($f:ident, $v0:ident, rep $v1:ident) => { ($f:ident, $v0:ident,rep $v1:ident) => {
($f($v0.0, $v1), $f($v0.1, $v1), $f($v0.2, $v1)) ($f($v0.0, $v1), $f($v0.1, $v1), $f($v0.2, $v1))
}; };
} }

View File

@@ -1,26 +1,28 @@
// FFT-BASED MDS MULTIPLICATION HELPER FUNCTIONS // FFT-BASED MDS MULTIPLICATION HELPER FUNCTIONS
// ================================================================================================ // ================================================================================================
/// This module contains helper functions as well as constants used to perform the vector-matrix //! This module contains helper functions as well as constants used to perform the vector-matrix
/// multiplication step of the Rescue prime permutation. The special form of our MDS matrix //! multiplication step of the Rescue prime permutation. The special form of our MDS matrix
/// i.e. being circular, allows us to reduce the vector-matrix multiplication to a Hadamard product //! i.e. being circular, allows us to reduce the vector-matrix multiplication to a Hadamard product
/// of two vectors in "frequency domain". This follows from the simple fact that every circulant //! of two vectors in "frequency domain". This follows from the simple fact that every circulant
/// matrix has the columns of the discrete Fourier transform matrix as orthogonal eigenvectors. //! matrix has the columns of the discrete Fourier transform matrix as orthogonal eigenvectors.
/// The implementation also avoids the use of 3-point FFTs, and 3-point iFFTs, and substitutes that //! The implementation also avoids the use of 3-point FFTs, and 3-point iFFTs, and substitutes that
/// with explicit expressions. It also avoids, due to the form of our matrix in the frequency domain, //! with explicit expressions. It also avoids, due to the form of our matrix in the frequency
/// divisions by 2 and repeated modular reductions. This is because of our explicit choice of //! domain, divisions by 2 and repeated modular reductions. This is because of our explicit choice
/// an MDS matrix that has small powers of 2 entries in frequency domain. //! of an MDS matrix that has small powers of 2 entries in frequency domain.
/// The following implementation has benefited greatly from the discussions and insights of //! The following implementation has benefited greatly from the discussions and insights of
/// Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero and is based on Nabaglo's Plonky2 //! Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero and is based on Nabaglo's Plonky2
/// implementation. //! implementation.
// Rescue MDS matrix in frequency domain. // Rescue MDS matrix in frequency domain.
//
// More precisely, this is the output of the three 4-point (real) FFTs of the first column of // More precisely, this is the output of the three 4-point (real) FFTs of the first column of
// the MDS matrix i.e. just before the multiplication with the appropriate twiddle factors // the MDS matrix i.e. just before the multiplication with the appropriate twiddle factors
// and application of the final four 3-point FFT in order to get the full 12-point FFT. // and application of the final four 3-point FFT in order to get the full 12-point FFT.
// The entries have been scaled appropriately in order to avoid divisions by 2 in iFFT2 and iFFT4. // The entries have been scaled appropriately in order to avoid divisions by 2 in iFFT2 and iFFT4.
// The code to generate the matrix in frequency domain is based on an adaptation of code, to generate // The code to generate the matrix in frequency domain is based on an adaptation of code, to
// MDS matrices efficiently in original domain, that was developed by the Polygon Zero team. // generate MDS matrices efficiently in original domain, that was developed by the Polygon Zero
// team.
const MDS_FREQ_BLOCK_ONE: [i64; 3] = [16, 8, 16]; const MDS_FREQ_BLOCK_ONE: [i64; 3] = [16, 8, 16];
const MDS_FREQ_BLOCK_TWO: [(i64, i64); 3] = [(-1, 2), (-1, 1), (4, 8)]; const MDS_FREQ_BLOCK_TWO: [(i64, i64); 3] = [(-1, 2), (-1, 1), (4, 8)];
const MDS_FREQ_BLOCK_THREE: [i64; 3] = [-8, 1, 1]; const MDS_FREQ_BLOCK_THREE: [i64; 3] = [-8, 1, 1];
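For intuition on the frequency-domain claim above: multiplying a circulant matrix by a vector is exactly a cyclic convolution, which a DFT turns into the pointwise (Hadamard) product this module exploits. A small sketch with naive evaluation (illustrative naming):

/// Multiplies the circulant matrix whose first column is `col` by `v`:
/// y[i] = sum_j col[(i - j) mod n] * v[j], i.e. the cyclic convolution of
/// col and v. In the frequency domain this becomes DFT(col) .* DFT(v).
fn circulant_mul(col: &[i64], v: &[i64]) -> Vec<i64> {
    let n = col.len();
    (0..n)
        .map(|i| (0..n).map(|j| col[(i + n - j) % n] * v[j]).sum())
        .collect()
}

fn main() {
    // The circulant matrix [[1, 3, 2], [2, 1, 3], [3, 2, 1]] times [1, 0, 0]
    // returns its first column.
    assert_eq!(circulant_mul(&[1, 2, 3], &[1, 0, 0]), vec![1, 2, 3]);
}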

View File

@@ -1,8 +1,6 @@
use core::ops::Range; use core::ops::Range;
use super::{ use super::{CubeExtension, Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ZERO};
CubeExtension, Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ONE, ZERO,
};
mod arch; mod arch;
pub use arch::optimized::{add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox}; pub use arch::optimized::{add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox};
@@ -11,10 +9,10 @@ mod mds;
use mds::{apply_mds, MDS}; use mds::{apply_mds, MDS};
mod rpo; mod rpo;
pub use rpo::{Rpo256, RpoDigest}; pub use rpo::{Rpo256, RpoDigest, RpoDigestError};
mod rpx; mod rpx;
pub use rpx::{Rpx256, RpxDigest}; pub use rpx::{Rpx256, RpxDigest, RpxDigestError};
#[cfg(test)] #[cfg(test)]
mod tests; mod tests;

View File

@@ -1,5 +1,7 @@
use alloc::string::String; use alloc::string::String;
use core::{cmp::Ordering, fmt::Display, ops::Deref}; use core::{cmp::Ordering, fmt::Display, ops::Deref, slice};
use thiserror::Error;
use super::{Digest, Felt, StarkField, DIGEST_BYTES, DIGEST_SIZE, ZERO}; use super::{Digest, Felt, StarkField, DIGEST_BYTES, DIGEST_SIZE, ZERO};
use crate::{ use crate::{
@@ -19,6 +21,9 @@ use crate::{
pub struct RpoDigest([Felt; DIGEST_SIZE]); pub struct RpoDigest([Felt; DIGEST_SIZE]);
impl RpoDigest { impl RpoDigest {
/// The serialized size of the digest in bytes.
pub const SERIALIZED_SIZE: usize = DIGEST_BYTES;
pub const fn new(value: [Felt; DIGEST_SIZE]) -> Self { pub const fn new(value: [Felt; DIGEST_SIZE]) -> Self {
Self(value) Self(value)
} }
@@ -31,13 +36,19 @@ impl RpoDigest {
<Self as Digest>::as_bytes(self) <Self as Digest>::as_bytes(self)
} }
pub fn digests_as_elements<'a, I>(digests: I) -> impl Iterator<Item = &'a Felt> pub fn digests_as_elements_iter<'a, I>(digests: I) -> impl Iterator<Item = &'a Felt>
where where
I: Iterator<Item = &'a Self>, I: Iterator<Item = &'a Self>,
{ {
digests.flat_map(|d| d.0.iter()) digests.flat_map(|d| d.0.iter())
} }
pub fn digests_as_elements(digests: &[Self]) -> &[Felt] {
let p = digests.as_ptr();
let len = digests.len() * DIGEST_SIZE;
unsafe { slice::from_raw_parts(p as *const Felt, len) }
}
/// Returns hexadecimal representation of this digest prefixed with `0x`. /// Returns hexadecimal representation of this digest prefixed with `0x`.
pub fn to_hex(&self) -> String { pub fn to_hex(&self) -> String {
bytes_to_hex_string(self.as_bytes()) bytes_to_hex_string(self.as_bytes())
@@ -118,26 +129,145 @@ impl Randomizable for RpoDigest {
// CONVERSIONS: FROM RPO DIGEST // CONVERSIONS: FROM RPO DIGEST
// ================================================================================================ // ================================================================================================
impl From<&RpoDigest> for [Felt; DIGEST_SIZE] { #[derive(Debug, Error)]
fn from(value: &RpoDigest) -> Self { pub enum RpoDigestError {
value.0 #[error("failed to convert digest field element to {0}")]
TypeConversion(&'static str),
#[error("failed to convert to field element: {0}")]
InvalidFieldElement(String),
}
impl TryFrom<&RpoDigest> for [bool; DIGEST_SIZE] {
type Error = RpoDigestError;
fn try_from(value: &RpoDigest) -> Result<Self, Self::Error> {
(*value).try_into()
} }
} }
impl From<RpoDigest> for [Felt; DIGEST_SIZE] { impl TryFrom<RpoDigest> for [bool; DIGEST_SIZE] {
fn from(value: RpoDigest) -> Self { type Error = RpoDigestError;
value.0
fn try_from(value: RpoDigest) -> Result<Self, Self::Error> {
fn to_bool(v: u64) -> Option<bool> {
if v <= 1 {
Some(v == 1)
} else {
None
}
}
Ok([
to_bool(value.0[0].as_int()).ok_or(RpoDigestError::TypeConversion("bool"))?,
to_bool(value.0[1].as_int()).ok_or(RpoDigestError::TypeConversion("bool"))?,
to_bool(value.0[2].as_int()).ok_or(RpoDigestError::TypeConversion("bool"))?,
to_bool(value.0[3].as_int()).ok_or(RpoDigestError::TypeConversion("bool"))?,
])
}
}
impl TryFrom<&RpoDigest> for [u8; DIGEST_SIZE] {
type Error = RpoDigestError;
fn try_from(value: &RpoDigest) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<RpoDigest> for [u8; DIGEST_SIZE] {
type Error = RpoDigestError;
fn try_from(value: RpoDigest) -> Result<Self, Self::Error> {
Ok([
value.0[0]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u8"))?,
value.0[1]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u8"))?,
value.0[2]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u8"))?,
value.0[3]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u8"))?,
])
}
}
impl TryFrom<&RpoDigest> for [u16; DIGEST_SIZE] {
type Error = RpoDigestError;
fn try_from(value: &RpoDigest) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<RpoDigest> for [u16; DIGEST_SIZE] {
type Error = RpoDigestError;
fn try_from(value: RpoDigest) -> Result<Self, Self::Error> {
Ok([
value.0[0]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u16"))?,
value.0[1]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u16"))?,
value.0[2]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u16"))?,
value.0[3]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u16"))?,
])
}
}
impl TryFrom<&RpoDigest> for [u32; DIGEST_SIZE] {
type Error = RpoDigestError;
fn try_from(value: &RpoDigest) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<RpoDigest> for [u32; DIGEST_SIZE] {
type Error = RpoDigestError;
fn try_from(value: RpoDigest) -> Result<Self, Self::Error> {
Ok([
value.0[0]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u32"))?,
value.0[1]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u32"))?,
value.0[2]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u32"))?,
value.0[3]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u32"))?,
])
} }
} }
impl From<&RpoDigest> for [u64; DIGEST_SIZE] { impl From<&RpoDigest> for [u64; DIGEST_SIZE] {
fn from(value: &RpoDigest) -> Self { fn from(value: &RpoDigest) -> Self {
[ (*value).into()
value.0[0].as_int(),
value.0[1].as_int(),
value.0[2].as_int(),
value.0[3].as_int(),
]
} }
} }
@@ -152,9 +282,21 @@ impl From<RpoDigest> for [u64; DIGEST_SIZE] {
} }
} }
impl From<&RpoDigest> for [Felt; DIGEST_SIZE] {
fn from(value: &RpoDigest) -> Self {
(*value).into()
}
}
impl From<RpoDigest> for [Felt; DIGEST_SIZE] {
fn from(value: RpoDigest) -> Self {
value.0
}
}
impl From<&RpoDigest> for [u8; DIGEST_BYTES] { impl From<&RpoDigest> for [u8; DIGEST_BYTES] {
fn from(value: &RpoDigest) -> Self { fn from(value: &RpoDigest) -> Self {
value.as_bytes() (*value).into()
} }
} }
@@ -164,13 +306,6 @@ impl From<RpoDigest> for [u8; DIGEST_BYTES] {
} }
} }
impl From<RpoDigest> for String {
/// The returned string starts with `0x`.
fn from(value: RpoDigest) -> Self {
value.to_hex()
}
}
impl From<&RpoDigest> for String { impl From<&RpoDigest> for String {
/// The returned string starts with `0x`. /// The returned string starts with `0x`.
fn from(value: &RpoDigest) -> Self { fn from(value: &RpoDigest) -> Self {
@@ -178,13 +313,83 @@ impl From<&RpoDigest> for String {
} }
} }
impl From<RpoDigest> for String {
/// The returned string starts with `0x`.
fn from(value: RpoDigest) -> Self {
value.to_hex()
}
}
// CONVERSIONS: TO RPO DIGEST // CONVERSIONS: TO RPO DIGEST
// ================================================================================================ // ================================================================================================
#[derive(Copy, Clone, Debug)] impl From<&[bool; DIGEST_SIZE]> for RpoDigest {
pub enum RpoDigestError { fn from(value: &[bool; DIGEST_SIZE]) -> Self {
/// The provided u64 integer does not fit in the field's moduli. (*value).into()
InvalidInteger, }
}
impl From<[bool; DIGEST_SIZE]> for RpoDigest {
fn from(value: [bool; DIGEST_SIZE]) -> Self {
[value[0] as u32, value[1] as u32, value[2] as u32, value[3] as u32].into()
}
}
impl From<&[u8; DIGEST_SIZE]> for RpoDigest {
fn from(value: &[u8; DIGEST_SIZE]) -> Self {
(*value).into()
}
}
impl From<[u8; DIGEST_SIZE]> for RpoDigest {
fn from(value: [u8; DIGEST_SIZE]) -> Self {
Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
}
}
impl From<&[u16; DIGEST_SIZE]> for RpoDigest {
fn from(value: &[u16; DIGEST_SIZE]) -> Self {
(*value).into()
}
}
impl From<[u16; DIGEST_SIZE]> for RpoDigest {
fn from(value: [u16; DIGEST_SIZE]) -> Self {
Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
}
}
impl From<&[u32; DIGEST_SIZE]> for RpoDigest {
fn from(value: &[u32; DIGEST_SIZE]) -> Self {
(*value).into()
}
}
impl From<[u32; DIGEST_SIZE]> for RpoDigest {
fn from(value: [u32; DIGEST_SIZE]) -> Self {
Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
}
}
impl TryFrom<&[u64; DIGEST_SIZE]> for RpoDigest {
type Error = RpoDigestError;
fn try_from(value: &[u64; DIGEST_SIZE]) -> Result<Self, RpoDigestError> {
(*value).try_into()
}
}
impl TryFrom<[u64; DIGEST_SIZE]> for RpoDigest {
type Error = RpoDigestError;
fn try_from(value: [u64; DIGEST_SIZE]) -> Result<Self, RpoDigestError> {
Ok(Self([
value[0].try_into().map_err(RpoDigestError::InvalidFieldElement)?,
value[1].try_into().map_err(RpoDigestError::InvalidFieldElement)?,
value[2].try_into().map_err(RpoDigestError::InvalidFieldElement)?,
value[3].try_into().map_err(RpoDigestError::InvalidFieldElement)?,
]))
}
} }
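A hedged usage sketch of the fallible conversions above (the import path assumes the crate's public re-exports shown elsewhere in this diff):

use miden_crypto::hash::rpo::{RpoDigest, RpoDigestError};

fn main() -> Result<(), RpoDigestError> {
    // Succeeds: every limb is a valid field element (less than the modulus).
    let digest = RpoDigest::try_from([0u64, 1, 2, 3])?;
    let limbs: [u64; 4] = digest.into();
    assert_eq!(limbs, [0, 1, 2, 3]);

    // Fails: u64::MAX is not a canonical Goldilocks field element.
    assert!(RpoDigest::try_from([u64::MAX, 0, 0, 0]).is_err());
    Ok(())
}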
impl From<&[Felt; DIGEST_SIZE]> for RpoDigest { impl From<&[Felt; DIGEST_SIZE]> for RpoDigest {
@@ -199,6 +404,14 @@ impl From<[Felt; DIGEST_SIZE]> for RpoDigest {
} }
} }
impl TryFrom<&[u8; DIGEST_BYTES]> for RpoDigest {
type Error = HexParseError;
fn try_from(value: &[u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<[u8; DIGEST_BYTES]> for RpoDigest { impl TryFrom<[u8; DIGEST_BYTES]> for RpoDigest {
type Error = HexParseError; type Error = HexParseError;
@@ -218,14 +431,6 @@ impl TryFrom<[u8; DIGEST_BYTES]> for RpoDigest {
} }
} }
impl TryFrom<&[u8; DIGEST_BYTES]> for RpoDigest {
type Error = HexParseError;
fn try_from(value: &[u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<&[u8]> for RpoDigest { impl TryFrom<&[u8]> for RpoDigest {
type Error = HexParseError; type Error = HexParseError;
@@ -234,33 +439,12 @@ impl TryFrom<&[u8]> for RpoDigest {
} }
} }
impl TryFrom<[u64; DIGEST_SIZE]> for RpoDigest {
type Error = RpoDigestError;
fn try_from(value: [u64; DIGEST_SIZE]) -> Result<Self, RpoDigestError> {
Ok(Self([
value[0].try_into().map_err(|_| RpoDigestError::InvalidInteger)?,
value[1].try_into().map_err(|_| RpoDigestError::InvalidInteger)?,
value[2].try_into().map_err(|_| RpoDigestError::InvalidInteger)?,
value[3].try_into().map_err(|_| RpoDigestError::InvalidInteger)?,
]))
}
}
impl TryFrom<&[u64; DIGEST_SIZE]> for RpoDigest {
type Error = RpoDigestError;
fn try_from(value: &[u64; DIGEST_SIZE]) -> Result<Self, RpoDigestError> {
(*value).try_into()
}
}
impl TryFrom<&str> for RpoDigest { impl TryFrom<&str> for RpoDigest {
type Error = HexParseError; type Error = HexParseError;
/// Expects the string to start with `0x`. /// Expects the string to start with `0x`.
fn try_from(value: &str) -> Result<Self, Self::Error> { fn try_from(value: &str) -> Result<Self, Self::Error> {
hex_to_bytes(value).and_then(|v| v.try_into()) hex_to_bytes::<DIGEST_BYTES>(value).and_then(RpoDigest::try_from)
} }
} }
@@ -289,6 +473,10 @@ impl Serializable for RpoDigest {
fn write_into<W: ByteWriter>(&self, target: &mut W) { fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write_bytes(&self.as_bytes()); target.write_bytes(&self.as_bytes());
} }
fn get_size_hint(&self) -> usize {
Self::SERIALIZED_SIZE
}
} }
impl Deserializable for RpoDigest { impl Deserializable for RpoDigest {
@@ -325,6 +513,7 @@ impl IntoIterator for RpoDigest {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use alloc::string::String; use alloc::string::String;
use rand_utils::rand_value; use rand_utils::rand_value;
use super::{Deserializable, Felt, RpoDigest, Serializable, DIGEST_BYTES, DIGEST_SIZE}; use super::{Deserializable, Felt, RpoDigest, Serializable, DIGEST_BYTES, DIGEST_SIZE};
@@ -342,6 +531,7 @@ mod tests {
let mut bytes = vec![]; let mut bytes = vec![];
d1.write_into(&mut bytes); d1.write_into(&mut bytes);
assert_eq!(DIGEST_BYTES, bytes.len()); assert_eq!(DIGEST_BYTES, bytes.len());
assert_eq!(bytes.len(), d1.get_size_hint());
let mut reader = SliceReader::new(&bytes); let mut reader = SliceReader::new(&bytes);
let d2 = RpoDigest::read_from(&mut reader).unwrap(); let d2 = RpoDigest::read_from(&mut reader).unwrap();
@@ -373,44 +563,72 @@ mod tests {
Felt::new(rand_value()), Felt::new(rand_value()),
]); ]);
let v: [Felt; DIGEST_SIZE] = digest.into(); // BY VALUE
// ----------------------------------------------------------------------------------------
let v: [bool; DIGEST_SIZE] = [true, false, true, true];
let v2: RpoDigest = v.into(); let v2: RpoDigest = v.into();
assert_eq!(digest, v2); assert_eq!(v, <[bool; DIGEST_SIZE]>::try_from(v2).unwrap());
let v: [Felt; DIGEST_SIZE] = (&digest).into(); let v: [u8; DIGEST_SIZE] = [0_u8, 1_u8, 2_u8, 3_u8];
let v2: RpoDigest = v.into(); let v2: RpoDigest = v.into();
assert_eq!(digest, v2); assert_eq!(v, <[u8; DIGEST_SIZE]>::try_from(v2).unwrap());
let v: [u16; DIGEST_SIZE] = [0_u16, 1_u16, 2_u16, 3_u16];
let v2: RpoDigest = v.into();
assert_eq!(v, <[u16; DIGEST_SIZE]>::try_from(v2).unwrap());
let v: [u32; DIGEST_SIZE] = [0_u32, 1_u32, 2_u32, 3_u32];
let v2: RpoDigest = v.into();
assert_eq!(v, <[u32; DIGEST_SIZE]>::try_from(v2).unwrap());
let v: [u64; DIGEST_SIZE] = digest.into(); let v: [u64; DIGEST_SIZE] = digest.into();
let v2: RpoDigest = v.try_into().unwrap(); let v2: RpoDigest = v.try_into().unwrap();
assert_eq!(digest, v2); assert_eq!(digest, v2);
let v: [u64; DIGEST_SIZE] = (&digest).into(); let v: [Felt; DIGEST_SIZE] = digest.into();
let v2: RpoDigest = v.try_into().unwrap(); let v2: RpoDigest = v.into();
assert_eq!(digest, v2); assert_eq!(digest, v2);
let v: [u8; DIGEST_BYTES] = digest.into(); let v: [u8; DIGEST_BYTES] = digest.into();
let v2: RpoDigest = v.try_into().unwrap(); let v2: RpoDigest = v.try_into().unwrap();
assert_eq!(digest, v2); assert_eq!(digest, v2);
let v: [u8; DIGEST_BYTES] = (&digest).into();
let v2: RpoDigest = v.try_into().unwrap();
assert_eq!(digest, v2);
let v: String = digest.into(); let v: String = digest.into();
let v2: RpoDigest = v.try_into().unwrap(); let v2: RpoDigest = v.try_into().unwrap();
assert_eq!(digest, v2); assert_eq!(digest, v2);
let v: String = (&digest).into(); // BY REF
let v2: RpoDigest = v.try_into().unwrap(); // ----------------------------------------------------------------------------------------
let v: [bool; DIGEST_SIZE] = [true, false, true, true];
let v2: RpoDigest = (&v).into();
assert_eq!(v, <[bool; DIGEST_SIZE]>::try_from(&v2).unwrap());
let v: [u8; DIGEST_SIZE] = [0_u8, 1_u8, 2_u8, 3_u8];
let v2: RpoDigest = (&v).into();
assert_eq!(v, <[u8; DIGEST_SIZE]>::try_from(&v2).unwrap());
let v: [u16; DIGEST_SIZE] = [0_u16, 1_u16, 2_u16, 3_u16];
let v2: RpoDigest = (&v).into();
assert_eq!(v, <[u16; DIGEST_SIZE]>::try_from(&v2).unwrap());
let v: [u32; DIGEST_SIZE] = [0_u32, 1_u32, 2_u32, 3_u32];
let v2: RpoDigest = (&v).into();
assert_eq!(v, <[u32; DIGEST_SIZE]>::try_from(&v2).unwrap());
let v: [u64; DIGEST_SIZE] = (&digest).into();
let v2: RpoDigest = (&v).try_into().unwrap();
assert_eq!(digest, v2); assert_eq!(digest, v2);
let v: [u8; DIGEST_BYTES] = digest.into(); let v: [Felt; DIGEST_SIZE] = (&digest).into();
let v2: RpoDigest = (&v).try_into().unwrap(); let v2: RpoDigest = (&v).into();
assert_eq!(digest, v2); assert_eq!(digest, v2);
let v: [u8; DIGEST_BYTES] = (&digest).into(); let v: [u8; DIGEST_BYTES] = (&digest).into();
let v2: RpoDigest = (&v).try_into().unwrap(); let v2: RpoDigest = (&v).try_into().unwrap();
assert_eq!(digest, v2); assert_eq!(digest, v2);
let v: String = (&digest).into();
let v2: RpoDigest = (&v).try_into().unwrap();
assert_eq!(digest, v2);
} }
} }

View File

@@ -4,11 +4,11 @@ use super::{
add_constants, add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox, apply_inv_sbox, add_constants, add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox, apply_inv_sbox,
apply_mds, apply_sbox, Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ARK1, apply_mds, apply_sbox, Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ARK1,
ARK2, BINARY_CHUNK_SIZE, CAPACITY_RANGE, DIGEST_BYTES, DIGEST_RANGE, DIGEST_SIZE, INPUT1_RANGE, ARK2, BINARY_CHUNK_SIZE, CAPACITY_RANGE, DIGEST_BYTES, DIGEST_RANGE, DIGEST_SIZE, INPUT1_RANGE,
INPUT2_RANGE, MDS, NUM_ROUNDS, ONE, RATE_RANGE, RATE_WIDTH, STATE_WIDTH, ZERO, INPUT2_RANGE, MDS, NUM_ROUNDS, RATE_RANGE, RATE_WIDTH, STATE_WIDTH, ZERO,
}; };
mod digest; mod digest;
pub use digest::RpoDigest; pub use digest::{RpoDigest, RpoDigestError};
#[cfg(test)] #[cfg(test)]
mod tests; mod tests;
@@ -19,12 +19,14 @@ mod tests;
/// Implementation of the Rescue Prime Optimized hash function with 256-bit output. /// Implementation of the Rescue Prime Optimized hash function with 256-bit output.
/// ///
/// The hash function is implemented according to the Rescue Prime Optimized /// The hash function is implemented according to the Rescue Prime Optimized
/// [specifications](https://eprint.iacr.org/2022/1577) /// [specifications](https://eprint.iacr.org/2022/1577) while the padding rule follows the one
/// described [here](https://eprint.iacr.org/2023/1045).
/// ///
/// The parameters used to instantiate the function are: /// The parameters used to instantiate the function are:
/// * Field: 64-bit prime field with modulus 2^64 - 2^32 + 1. /// * Field: 64-bit prime field with modulus p = 2^64 - 2^32 + 1.
/// * State width: 12 field elements. /// * State width: 12 field elements.
/// * Capacity size: 4 field elements. /// * Rate size: r = 8 field elements.
/// * Capacity size: c = 4 field elements.
/// * Number of rounds: 7. /// * Number of rounds: 7.
/// * S-Box degree: 7. /// * S-Box degree: 7.
/// ///
@@ -50,8 +52,23 @@ mod tests;
/// ///
/// Thus, if the underlying data consists of valid field elements, it might make more sense /// Thus, if the underlying data consists of valid field elements, it might make more sense
/// to deserialize them into field elements and then hash them using /// to deserialize them into field elements and then hash them using
/// [hash_elements()](Rpo256::hash_elements) function rather then hashing the serialized bytes /// [hash_elements()](Rpo256::hash_elements) function rather than hashing the serialized bytes
/// using [hash()](Rpo256::hash) function. /// using [hash()](Rpo256::hash) function.
///
/// ## Domain separation
/// [merge_in_domain()](Rpo256::merge_in_domain) hashes two digests into one digest with some domain
/// identifier and the current implementation sets the second capacity element to the value of
/// this domain identifier. Using a similar argument to the one formulated for domain separation of
/// the RPX hash function in Appendix C of its [specification](https://eprint.iacr.org/2023/1045),
/// one sees that doing so degrades only pre-image resistance, from its initial bound of c.log_2(p),
/// by as much as the log_2 of the size of the domain identifier space. Since pre-image resistance
/// becomes the bottleneck for the security bound of the sponge in overwrite-mode only when it is
/// lower than 2^128, we see that the target 128-bit security level is maintained as long as
/// the size of the domain identifier space, including for padding, is less than 2^128.
///
/// ## Hashing of empty input
/// The current implementation hashes empty input to the zero digest [0, 0, 0, 0]. This has
/// the benefit of requiring no calls to the RPO permutation when hashing empty input.
#[derive(Debug, Copy, Clone, Eq, PartialEq)] #[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Rpo256(); pub struct Rpo256();
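A small behavioral sketch of the domain-separation property described above. Import paths are assumed; the ZERO-domain equality is the behavior documented in the tests later in this diff:

use miden_crypto::{hash::rpo::Rpo256, ONE, ZERO};
// `Hasher` is assumed to be re-exported from the crate's hash module.
use miden_crypto::hash::Hasher;

fn main() {
    let digests = [Rpo256::hash(b"left"), Rpo256::hash(b"right")];

    // A zero domain identifier leaves the capacity untouched, so the result
    // matches a plain merge; any nonzero domain yields a distinct digest.
    assert_eq!(Rpo256::merge(&digests), Rpo256::merge_in_domain(&digests, ZERO));
    assert_ne!(Rpo256::merge(&digests), Rpo256::merge_in_domain(&digests, ONE));
}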
@@ -65,14 +82,16 @@ impl Hasher for Rpo256 {
// initialize the state with zeroes // initialize the state with zeroes
let mut state = [ZERO; STATE_WIDTH]; let mut state = [ZERO; STATE_WIDTH];
// set the capacity (first element) to a flag on whether or not the input length is evenly // determine the number of field elements needed to encode `bytes` when each field element
// divided by the rate. this will prevent collisions between padded and non-padded inputs, // represents at most 7 bytes.
// and will rule out the need to perform an extra permutation in case of evenly divided let num_field_elem = bytes.len().div_ceil(BINARY_CHUNK_SIZE);
// inputs.
let is_rate_multiple = bytes.len() % RATE_WIDTH == 0; // set the first capacity element to `RATE_WIDTH + (num_field_elem % RATE_WIDTH)`. We do
if !is_rate_multiple { // this to achieve:
state[CAPACITY_RANGE.start] = ONE; // 1. Domain separating hashing of `[u8]` from hashing of `[Felt]`.
} // 2. Avoiding collisions at the `[Felt]` representation of the encoded bytes.
state[CAPACITY_RANGE.start] =
Felt::from((RATE_WIDTH + (num_field_elem % RATE_WIDTH)) as u8);
// initialize a buffer to receive the little-endian elements. // initialize a buffer to receive the little-endian elements.
let mut buf = [0_u8; 8]; let mut buf = [0_u8; 8];
@@ -81,41 +100,49 @@ impl Hasher for Rpo256 {
// into the state. // into the state.
// //
// every time the rate range is filled, a permutation is performed. if the final value of // every time the rate range is filled, a permutation is performed. if the final value of
// `i` is not zero, then the chunks count wasn't enough to fill the state range, and an // `rate_pos` is not zero, then the chunks count wasn't enough to fill the state range,
// additional permutation must be performed. // and an additional permutation must be performed.
let i = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |i, chunk| { let mut current_chunk_idx = 0_usize;
// the last element of the iteration may or may not be a full chunk. if it's not, then // handle the case of an empty `bytes`
// we need to pad the remainder bytes of the chunk with zeroes, separated by a `1`. let last_chunk_idx = if num_field_elem == 0 {
// this will avoid collisions. current_chunk_idx
if chunk.len() == BINARY_CHUNK_SIZE { } else {
num_field_elem - 1
};
let rate_pos = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |rate_pos, chunk| {
// copy the chunk into the buffer
if current_chunk_idx != last_chunk_idx {
buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk); buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk);
} else { } else {
// on the last iteration, we pad `buf` with a 1 followed by as many 0's as are
// needed to fill it
buf.fill(0); buf.fill(0);
buf[..chunk.len()].copy_from_slice(chunk); buf[..chunk.len()].copy_from_slice(chunk);
buf[chunk.len()] = 1; buf[chunk.len()] = 1;
} }
current_chunk_idx += 1;
// set the current rate element to the input. since we take at most 7 bytes, we are // set the current rate element to the input. since we take at most 7 bytes, we are
// guaranteed that the inputs data will fit into a single field element. // guaranteed that the inputs data will fit into a single field element.
state[RATE_RANGE.start + i] = Felt::new(u64::from_le_bytes(buf)); state[RATE_RANGE.start + rate_pos] = Felt::new(u64::from_le_bytes(buf));
// proceed filling the range. if it's full, then we apply a permutation and reset the // proceed filling the range. if it's full, then we apply a permutation and reset the
// counter to the beginning of the range. // counter to the beginning of the range.
if i == RATE_WIDTH - 1 { if rate_pos == RATE_WIDTH - 1 {
Self::apply_permutation(&mut state); Self::apply_permutation(&mut state);
0 0
} else { } else {
i + 1 rate_pos + 1
} }
}); });
// if we absorbed some elements but didn't apply a permutation to them (would happen when // if we absorbed some elements but didn't apply a permutation to them (would happen when
// the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation. we // the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation. we
// don't need to apply any extra padding because the first capacity element contains a // don't need to apply any extra padding because the first capacity element contains a
// flag indicating whether the input is evenly divisible by the rate. // flag indicating the number of field elements constituting the last block when the latter
if i != 0 { // is not divisible by `RATE_WIDTH`.
state[RATE_RANGE.start + i..RATE_RANGE.end].fill(ZERO); if rate_pos != 0 {
state[RATE_RANGE.start + i] = ONE; state[RATE_RANGE.start + rate_pos..RATE_RANGE.end].fill(ZERO);
Self::apply_permutation(&mut state); Self::apply_permutation(&mut state);
} }
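A standalone sketch of the capacity-flag rule introduced above (the constants mirror the ones in the diff; `capacity_flag` is a hypothetical helper, not crate API):

const BINARY_CHUNK_SIZE: usize = 7; // bytes packed per field element
const RATE_WIDTH: usize = 8;

/// RATE_WIDTH + (num_field_elem % RATE_WIDTH), as written into the first
/// capacity element when hashing raw bytes.
fn capacity_flag(num_bytes: usize) -> usize {
    let num_field_elem = num_bytes.div_ceil(BINARY_CHUNK_SIZE);
    RATE_WIDTH + (num_field_elem % RATE_WIDTH)
}

fn main() {
    assert_eq!(capacity_flag(0), 8); // empty input encodes to 0 elements
    assert_eq!(capacity_flag(56), 8); // 56 / 7 = 8 elements, a full rate block
    assert_eq!(capacity_flag(100), 15); // ceil(100 / 7) = 15 -> 8 + 15 % 8
}

Because the flag always lies in [8, 15], byte hashing stays domain-separated from element hashing, where (per the `hash_elements` hunk below) the flag is just `elements.len() % RATE_WIDTH`, i.e. in [0, 7].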
@@ -127,7 +154,7 @@ impl Hasher for Rpo256 {
// initialize the state by copying the digest elements into the rate portion of the state // initialize the state by copying the digest elements into the rate portion of the state
// (8 total elements), and set the capacity elements to 0. // (8 total elements), and set the capacity elements to 0.
let mut state = [ZERO; STATE_WIDTH]; let mut state = [ZERO; STATE_WIDTH];
let it = Self::Digest::digests_as_elements(values.iter()); let it = Self::Digest::digests_as_elements_iter(values.iter());
for (i, v) in it.enumerate() { for (i, v) in it.enumerate() {
state[RATE_RANGE.start + i] = *v; state[RATE_RANGE.start + i] = *v;
} }
@@ -137,29 +164,28 @@ impl Hasher for Rpo256 {
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap()) RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
} }
fn merge_many(values: &[Self::Digest]) -> Self::Digest {
Self::hash_elements(Self::Digest::digests_as_elements(values))
}
fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest { fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
// initialize the state as follows: // initialize the state as follows:
// - seed is copied into the first 4 elements of the rate portion of the state. // - seed is copied into the first 4 elements of the rate portion of the state.
// - if the value fits into a single field element, copy it into the fifth rate element // - if the value fits into a single field element, copy it into the fifth rate element and
// and set the sixth rate element to 1. // set the first capacity element to 5.
// - if the value doesn't fit into a single field element, split it into two field // - if the value doesn't fit into a single field element, split it into two field elements,
// elements, copy them into rate elements 5 and 6, and set the seventh rate element // copy them into rate elements 5 and 6 and set the first capacity element to 6.
// to 1.
// - set the first capacity element to 1
let mut state = [ZERO; STATE_WIDTH]; let mut state = [ZERO; STATE_WIDTH];
state[INPUT1_RANGE].copy_from_slice(seed.as_elements()); state[INPUT1_RANGE].copy_from_slice(seed.as_elements());
state[INPUT2_RANGE.start] = Felt::new(value); state[INPUT2_RANGE.start] = Felt::new(value);
if value < Felt::MODULUS { if value < Felt::MODULUS {
state[INPUT2_RANGE.start + 1] = ONE; state[CAPACITY_RANGE.start] = Felt::from(5_u8);
} else { } else {
state[INPUT2_RANGE.start + 1] = Felt::new(value / Felt::MODULUS); state[INPUT2_RANGE.start + 1] = Felt::new(value / Felt::MODULUS);
state[INPUT2_RANGE.start + 2] = ONE; state[CAPACITY_RANGE.start] = Felt::from(6_u8);
} }
// common padding for both cases // apply the RPO permutation and return the first four elements of the rate
state[CAPACITY_RANGE.start] = ONE;
// apply the RPO permutation and return the first four elements of the state
Self::apply_permutation(&mut state); Self::apply_permutation(&mut state);
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap()) RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
} }
@@ -173,11 +199,9 @@ impl ElementHasher for Rpo256 {
let elements = E::slice_as_base_elements(elements); let elements = E::slice_as_base_elements(elements);
// initialize state to all zeros, except for the first element of the capacity part, which // initialize state to all zeros, except for the first element of the capacity part, which
// is set to 1 if the number of elements is not a multiple of RATE_WIDTH. // is set to `elements.len() % RATE_WIDTH`.
let mut state = [ZERO; STATE_WIDTH]; let mut state = [ZERO; STATE_WIDTH];
if elements.len() % RATE_WIDTH != 0 { state[CAPACITY_RANGE.start] = Self::BaseField::from((elements.len() % RATE_WIDTH) as u8);
state[CAPACITY_RANGE.start] = ONE;
}
// absorb elements into the state one by one until the rate portion of the state is filled // absorb elements into the state one by one until the rate portion of the state is filled
// up; then apply the Rescue permutation and start absorbing again; repeat until all // up; then apply the Rescue permutation and start absorbing again; repeat until all
@@ -194,11 +218,8 @@ impl ElementHasher for Rpo256 {
// if we absorbed some elements but didn't apply a permutation to them (would happen when // if we absorbed some elements but didn't apply a permutation to them (would happen when
// the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation after // the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation after
// padding by appending a 1 followed by as many 0s as necessary to make the input length a // padding by as many 0s as necessary to make the input length a multiple of the RATE_WIDTH.
// multiple of the RATE_WIDTH.
if i > 0 { if i > 0 {
state[RATE_RANGE.start + i] = ONE;
i += 1;
while i != RATE_WIDTH { while i != RATE_WIDTH {
state[RATE_RANGE.start + i] = ZERO; state[RATE_RANGE.start + i] = ZERO;
i += 1; i += 1;
@@ -273,7 +294,7 @@ impl Rpo256 {
// initialize the state by copying the digest elements into the rate portion of the state // initialize the state by copying the digest elements into the rate portion of the state
// (8 total elements), and set the capacity elements to 0. // (8 total elements), and set the capacity elements to 0.
let mut state = [ZERO; STATE_WIDTH]; let mut state = [ZERO; STATE_WIDTH];
let it = RpoDigest::digests_as_elements(values.iter()); let it = RpoDigest::digests_as_elements_iter(values.iter());
for (i, v) in it.enumerate() { for (i, v) in it.enumerate() {
state[RATE_RANGE.start + i] = *v; state[RATE_RANGE.start + i] = *v;
} }

View File

@@ -1,12 +1,16 @@
use alloc::{collections::BTreeSet, vec::Vec};
use proptest::prelude::*; use proptest::prelude::*;
use rand_utils::rand_value; use rand_utils::rand_value;
use super::{ use super::{
super::{apply_inv_sbox, apply_sbox, ALPHA, INV_ALPHA}, super::{apply_inv_sbox, apply_sbox, ALPHA, INV_ALPHA},
Felt, FieldElement, Hasher, Rpo256, RpoDigest, StarkField, ONE, STATE_WIDTH, ZERO, Felt, FieldElement, Hasher, Rpo256, RpoDigest, StarkField, STATE_WIDTH, ZERO,
};
use crate::{
hash::rescue::{BINARY_CHUNK_SIZE, CAPACITY_RANGE, RATE_WIDTH},
Word, ONE,
}; };
use crate::Word;
use alloc::{collections::BTreeSet, vec::Vec};
#[test]
fn test_sbox() {
@@ -58,7 +62,7 @@ fn merge_vs_merge_in_domain() {
    ];
    let merge_result = Rpo256::merge(&digests);

-    // ------------- merge with domain = 0 ----------------------------------------------------------
+    // ------------- merge with domain = 0 -------------
    // set domain to ZERO. This should not change the result.
    let domain = ZERO;
@@ -66,7 +70,7 @@ fn merge_vs_merge_in_domain() {
    let merge_in_domain_result = Rpo256::merge_in_domain(&digests, domain);
    assert_eq!(merge_result, merge_in_domain_result);

-    // ------------- merge with domain = 1 ----------------------------------------------------------
+    // ------------- merge with domain = 1 -------------
    // set domain to ONE. This should change the result.
    let domain = ONE;
@@ -125,6 +129,27 @@ fn hash_padding() {
    assert_ne!(r1, r2);
}
#[test]
fn hash_padding_no_extra_permutation_call() {
use crate::hash::rescue::DIGEST_RANGE;
// Implementation
let num_bytes = BINARY_CHUNK_SIZE * RATE_WIDTH;
let mut buffer = vec![0_u8; num_bytes];
*buffer.last_mut().unwrap() = 97;
let r1 = Rpo256::hash(&buffer);
// Expected
let final_chunk = [0_u8, 0, 0, 0, 0, 0, 97, 1];
let mut state = [ZERO; STATE_WIDTH];
// padding when hashing bytes
state[CAPACITY_RANGE.start] = Felt::from(RATE_WIDTH as u8);
*state.last_mut().unwrap() = Felt::new(u64::from_le_bytes(final_chunk));
Rpo256::apply_permutation(&mut state);
assert_eq!(&r1[0..4], &state[DIGEST_RANGE]);
}
#[test]
fn hash_elements_padding() {
    let e1 = [Felt::new(rand_value()); 2];
@@ -158,6 +183,24 @@ fn hash_elements() {
    assert_eq!(m_result, h_result);
}
#[test]
fn hash_empty() {
let elements: Vec<Felt> = vec![];
let zero_digest = RpoDigest::default();
let h_result = Rpo256::hash_elements(&elements);
assert_eq!(zero_digest, h_result);
}
#[test]
fn hash_empty_bytes() {
let bytes: Vec<u8> = vec![];
let zero_digest = RpoDigest::default();
let h_result = Rpo256::hash(&bytes);
assert_eq!(zero_digest, h_result);
}
#[test]
fn hash_test_vectors() {
    let elements = [
@@ -228,46 +271,46 @@ proptest! {
const EXPECTED: [Word; 19] = [
    [
-        Felt::new(1502364727743950833),
-        Felt::new(5880949717274681448),
-        Felt::new(162790463902224431),
-        Felt::new(6901340476773664264),
+        Felt::new(18126731724905382595),
+        Felt::new(7388557040857728717),
+        Felt::new(14290750514634285295),
+        Felt::new(7852282086160480146),
    ],
    [
-        Felt::new(7478710183745780580),
-        Felt::new(3308077307559720969),
-        Felt::new(3383561985796182409),
-        Felt::new(17205078494700259815),
+        Felt::new(10139303045932500183),
+        Felt::new(2293916558361785533),
+        Felt::new(15496361415980502047),
+        Felt::new(17904948502382283940),
    ],
    [
-        Felt::new(17439912364295172999),
-        Felt::new(17979156346142712171),
-        Felt::new(8280795511427637894),
-        Felt::new(9349844417834368814),
+        Felt::new(17457546260239634015),
+        Felt::new(803990662839494686),
+        Felt::new(10386005777401424878),
+        Felt::new(18168807883298448638),
    ],
    [
-        Felt::new(5105868198472766874),
-        Felt::new(13090564195691924742),
-        Felt::new(1058904296915798891),
-        Felt::new(18379501748825152268),
+        Felt::new(13072499238647455740),
+        Felt::new(10174350003422057273),
+        Felt::new(9201651627651151113),
+        Felt::new(6872461887313298746),
    ],
    [
-        Felt::new(9133662113608941286),
-        Felt::new(12096627591905525991),
-        Felt::new(14963426595993304047),
-        Felt::new(13290205840019973377),
+        Felt::new(2903803350580990546),
+        Felt::new(1838870750730563299),
+        Felt::new(4258619137315479708),
+        Felt::new(17334260395129062936),
    ],
    [
-        Felt::new(3134262397541159485),
-        Felt::new(10106105871979362399),
-        Felt::new(138768814855329459),
-        Felt::new(15044809212457404677),
+        Felt::new(8571221005243425262),
+        Felt::new(3016595589318175865),
+        Felt::new(13933674291329928438),
+        Felt::new(678640375034313072),
    ],
    [
-        Felt::new(162696376578462826),
-        Felt::new(4991300494838863586),
-        Felt::new(660346084748120605),
-        Felt::new(13179389528641752698),
+        Felt::new(16314113978986502310),
+        Felt::new(14587622368743051587),
+        Felt::new(2808708361436818462),
+        Felt::new(10660517522478329440),
    ],
    [
        Felt::new(2242391899857912644),
@@ -276,46 +319,46 @@ const EXPECTED: [Word; 19] = [
        Felt::new(5046143039268215739),
    ],
    [
-        Felt::new(9585630502158073976),
-        Felt::new(1310051013427303477),
-        Felt::new(7491921222636097758),
-        Felt::new(9417501558995216762),
+        Felt::new(5218076004221736204),
+        Felt::new(17169400568680971304),
+        Felt::new(8840075572473868990),
+        Felt::new(12382372614369863623),
    ],
    [
-        Felt::new(1994394001720334744),
-        Felt::new(10866209900885216467),
-        Felt::new(13836092831163031683),
-        Felt::new(10814636682252756697),
+        Felt::new(9783834557155203486),
+        Felt::new(12317263104955018849),
+        Felt::new(3933748931816109604),
+        Felt::new(1843043029836917214),
    ],
    [
-        Felt::new(17486854790732826405),
-        Felt::new(17376549265955727562),
-        Felt::new(2371059831956435003),
-        Felt::new(17585704935858006533),
+        Felt::new(14498234468286984551),
+        Felt::new(16837257669834682387),
+        Felt::new(6664141123711355107),
+        Felt::new(4590460158294697186),
    ],
    [
-        Felt::new(11368277489137713825),
-        Felt::new(3906270146963049287),
-        Felt::new(10236262408213059745),
-        Felt::new(78552867005814007),
+        Felt::new(4661800562479916067),
+        Felt::new(11794407552792839953),
+        Felt::new(9037742258721863712),
+        Felt::new(6287820818064278819),
    ],
    [
-        Felt::new(17899847381280262181),
-        Felt::new(14717912805498651446),
-        Felt::new(10769146203951775298),
-        Felt::new(2774289833490417856),
+        Felt::new(7752693085194633729),
+        Felt::new(7379857372245835536),
+        Felt::new(9270229380648024178),
+        Felt::new(10638301488452560378),
    ],
    [
-        Felt::new(3794717687462954368),
-        Felt::new(4386865643074822822),
-        Felt::new(8854162840275334305),
-        Felt::new(7129983987107225269),
+        Felt::new(11542686762698783357),
+        Felt::new(15570714990728449027),
+        Felt::new(7518801014067819501),
+        Felt::new(12706437751337583515),
    ],
    [
-        Felt::new(7244773535611633983),
-        Felt::new(19359923075859320),
-        Felt::new(10898655967774994333),
-        Felt::new(9319339563065736480),
+        Felt::new(9553923701032839042),
+        Felt::new(7281190920209838818),
+        Felt::new(2488477917448393955),
+        Felt::new(5088955350303368837),
    ],
    [
        Felt::new(4935426252518736883),
@@ -324,21 +367,21 @@ const EXPECTED: [Word; 19] = [
        Felt::new(18159875708229758073),
    ],
    [
-        Felt::new(14871230873837295931),
-        Felt::new(11225255908868362971),
-        Felt::new(18100987641405432308),
-        Felt::new(1559244340089644233),
+        Felt::new(12795429638314178838),
+        Felt::new(14360248269767567855),
+        Felt::new(3819563852436765058),
+        Felt::new(10859123583999067291),
    ],
    [
-        Felt::new(8348203744950016968),
-        Felt::new(4041411241960726733),
-        Felt::new(17584743399305468057),
-        Felt::new(16836952610803537051),
+        Felt::new(2695742617679420093),
+        Felt::new(9151515850666059759),
+        Felt::new(15855828029180595485),
+        Felt::new(17190029785471463210),
    ],
    [
-        Felt::new(16139797453633030050),
-        Felt::new(1090233424040889412),
-        Felt::new(10770255347785669036),
-        Felt::new(16982398877290254028),
+        Felt::new(13205273108219124830),
+        Felt::new(2524898486192849221),
+        Felt::new(14618764355375283547),
+        Felt::new(10615614265042186874),
    ],
];


@@ -1,5 +1,7 @@
use alloc::string::String;
-use core::{cmp::Ordering, fmt::Display, ops::Deref};
+use core::{cmp::Ordering, fmt::Display, ops::Deref, slice};
+
+use thiserror::Error;

use super::{Digest, Felt, StarkField, DIGEST_BYTES, DIGEST_SIZE, ZERO};
use crate::{
@@ -19,6 +21,9 @@ use crate::{
pub struct RpxDigest([Felt; DIGEST_SIZE]);

impl RpxDigest {
+    /// The serialized size of the digest in bytes.
+    pub const SERIALIZED_SIZE: usize = DIGEST_BYTES;
+
    pub const fn new(value: [Felt; DIGEST_SIZE]) -> Self {
        Self(value)
    }
@@ -31,13 +36,19 @@ impl RpxDigest {
        <Self as Digest>::as_bytes(self)
    }

-    pub fn digests_as_elements<'a, I>(digests: I) -> impl Iterator<Item = &'a Felt>
+    pub fn digests_as_elements_iter<'a, I>(digests: I) -> impl Iterator<Item = &'a Felt>
    where
        I: Iterator<Item = &'a Self>,
    {
        digests.flat_map(|d| d.0.iter())
    }
pub fn digests_as_elements(digests: &[Self]) -> &[Felt] {
let p = digests.as_ptr();
let len = digests.len() * DIGEST_SIZE;
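        // SAFETY (assumed): this cast is sound only if `RpxDigest` is laid out identically to
        // `[Felt; DIGEST_SIZE]`, so that `len * DIGEST_SIZE` field elements are readable
        // through the digest pointer.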
unsafe { slice::from_raw_parts(p as *const Felt, len) }
}
    /// Returns hexadecimal representation of this digest prefixed with `0x`.
    pub fn to_hex(&self) -> String {
        bytes_to_hex_string(self.as_bytes())
@@ -118,26 +129,145 @@ impl Randomizable for RpxDigest {
// CONVERSIONS: FROM RPX DIGEST
// ================================================================================================

-impl From<&RpxDigest> for [Felt; DIGEST_SIZE] {
-    fn from(value: &RpxDigest) -> Self {
-        value.0
-    }
-}
-
-impl From<RpxDigest> for [Felt; DIGEST_SIZE] {
-    fn from(value: RpxDigest) -> Self {
-        value.0
-    }
-}
+#[derive(Debug, Error)]
+pub enum RpxDigestError {
+    #[error("failed to convert digest field element to {0}")]
+    TypeConversion(&'static str),
+    #[error("failed to convert to field element: {0}")]
+    InvalidFieldElement(String),
+}
+
+impl TryFrom<&RpxDigest> for [bool; DIGEST_SIZE] {
+    type Error = RpxDigestError;
+
+    fn try_from(value: &RpxDigest) -> Result<Self, Self::Error> {
+        (*value).try_into()
+    }
+}
+
+impl TryFrom<RpxDigest> for [bool; DIGEST_SIZE] {
+    type Error = RpxDigestError;
+
+    fn try_from(value: RpxDigest) -> Result<Self, Self::Error> {
+        fn to_bool(v: u64) -> Option<bool> {
+            if v <= 1 {
+                Some(v == 1)
+            } else {
+                None
+            }
+        }
+
+        Ok([
+            to_bool(value.0[0].as_int()).ok_or(RpxDigestError::TypeConversion("bool"))?,
+            to_bool(value.0[1].as_int()).ok_or(RpxDigestError::TypeConversion("bool"))?,
+            to_bool(value.0[2].as_int()).ok_or(RpxDigestError::TypeConversion("bool"))?,
+            to_bool(value.0[3].as_int()).ok_or(RpxDigestError::TypeConversion("bool"))?,
+        ])
+    }
+}
impl TryFrom<&RpxDigest> for [u8; DIGEST_SIZE] {
type Error = RpxDigestError;
fn try_from(value: &RpxDigest) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<RpxDigest> for [u8; DIGEST_SIZE] {
type Error = RpxDigestError;
fn try_from(value: RpxDigest) -> Result<Self, Self::Error> {
Ok([
value.0[0]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u8"))?,
value.0[1]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u8"))?,
value.0[2]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u8"))?,
value.0[3]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u8"))?,
])
}
}
impl TryFrom<&RpxDigest> for [u16; DIGEST_SIZE] {
type Error = RpxDigestError;
fn try_from(value: &RpxDigest) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<RpxDigest> for [u16; DIGEST_SIZE] {
type Error = RpxDigestError;
fn try_from(value: RpxDigest) -> Result<Self, Self::Error> {
Ok([
value.0[0]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u16"))?,
value.0[1]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u16"))?,
value.0[2]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u16"))?,
value.0[3]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u16"))?,
])
}
}
impl TryFrom<&RpxDigest> for [u32; DIGEST_SIZE] {
type Error = RpxDigestError;
fn try_from(value: &RpxDigest) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<RpxDigest> for [u32; DIGEST_SIZE] {
type Error = RpxDigestError;
fn try_from(value: RpxDigest) -> Result<Self, Self::Error> {
Ok([
value.0[0]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u32"))?,
value.0[1]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u32"))?,
value.0[2]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u32"))?,
value.0[3]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u32"))?,
])
} }
} }
impl From<&RpxDigest> for [u64; DIGEST_SIZE] { impl From<&RpxDigest> for [u64; DIGEST_SIZE] {
fn from(value: &RpxDigest) -> Self { fn from(value: &RpxDigest) -> Self {
[ (*value).into()
value.0[0].as_int(),
value.0[1].as_int(),
value.0[2].as_int(),
value.0[3].as_int(),
]
} }
} }
@@ -152,6 +282,18 @@ impl From<RpxDigest> for [u64; DIGEST_SIZE] {
    }
}
impl From<&RpxDigest> for [Felt; DIGEST_SIZE] {
fn from(value: &RpxDigest) -> Self {
value.0
}
}
impl From<RpxDigest> for [Felt; DIGEST_SIZE] {
fn from(value: RpxDigest) -> Self {
value.0
}
}
impl From<&RpxDigest> for [u8; DIGEST_BYTES] {
    fn from(value: &RpxDigest) -> Self {
        value.as_bytes()
@@ -164,13 +306,6 @@ impl From<RpxDigest> for [u8; DIGEST_BYTES] {
    }
}

-impl From<RpxDigest> for String {
-    /// The returned string starts with `0x`.
-    fn from(value: RpxDigest) -> Self {
-        value.to_hex()
-    }
-}
-
impl From<&RpxDigest> for String {
    /// The returned string starts with `0x`.
    fn from(value: &RpxDigest) -> Self {
@@ -178,13 +313,83 @@ impl From<&RpxDigest> for String {
    }
}
impl From<RpxDigest> for String {
/// The returned string starts with `0x`.
fn from(value: RpxDigest) -> Self {
value.to_hex()
}
}
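Narrowing conversions out of a digest are fallible because each element is a roughly 64-bit field value. A usage sketch of the new error type (assuming the `miden_crypto::hash::rpx` re-exports shown in the module file below, and a std environment for printing):

use miden_crypto::hash::rpx::{RpxDigest, RpxDigestError};

fn limbs(digest: RpxDigest) -> Option<[u32; 4]> {
    // succeeds only if every field element fits in a u32
    match <[u32; 4]>::try_from(digest) {
        Ok(words) => Some(words),
        Err(RpxDigestError::TypeConversion(ty)) => {
            println!("an element did not fit in a {ty}");
            None
        },
        Err(other) => {
            println!("conversion failed: {other}");
            None
        },
    }
}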
// CONVERSIONS: TO RPX DIGEST
// ================================================================================================

-#[derive(Copy, Clone, Debug)]
-pub enum RpxDigestError {
-    /// The provided u64 integer does not fit in the field's moduli.
-    InvalidInteger,
-}
+impl From<&[bool; DIGEST_SIZE]> for RpxDigest {
+    fn from(value: &[bool; DIGEST_SIZE]) -> Self {
+        (*value).into()
+    }
+}
impl From<[bool; DIGEST_SIZE]> for RpxDigest {
fn from(value: [bool; DIGEST_SIZE]) -> Self {
[value[0] as u32, value[1] as u32, value[2] as u32, value[3] as u32].into()
}
}
impl From<&[u8; DIGEST_SIZE]> for RpxDigest {
fn from(value: &[u8; DIGEST_SIZE]) -> Self {
(*value).into()
}
}
impl From<[u8; DIGEST_SIZE]> for RpxDigest {
fn from(value: [u8; DIGEST_SIZE]) -> Self {
Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
}
}
impl From<&[u16; DIGEST_SIZE]> for RpxDigest {
fn from(value: &[u16; DIGEST_SIZE]) -> Self {
(*value).into()
}
}
impl From<[u16; DIGEST_SIZE]> for RpxDigest {
fn from(value: [u16; DIGEST_SIZE]) -> Self {
Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
}
}
impl From<&[u32; DIGEST_SIZE]> for RpxDigest {
fn from(value: &[u32; DIGEST_SIZE]) -> Self {
(*value).into()
}
}
impl From<[u32; DIGEST_SIZE]> for RpxDigest {
fn from(value: [u32; DIGEST_SIZE]) -> Self {
Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
}
}
impl TryFrom<&[u64; DIGEST_SIZE]> for RpxDigest {
type Error = RpxDigestError;
fn try_from(value: &[u64; DIGEST_SIZE]) -> Result<Self, RpxDigestError> {
(*value).try_into()
}
}
impl TryFrom<[u64; DIGEST_SIZE]> for RpxDigest {
type Error = RpxDigestError;
fn try_from(value: [u64; DIGEST_SIZE]) -> Result<Self, RpxDigestError> {
Ok(Self([
value[0].try_into().map_err(RpxDigestError::InvalidFieldElement)?,
value[1].try_into().map_err(RpxDigestError::InvalidFieldElement)?,
value[2].try_into().map_err(RpxDigestError::InvalidFieldElement)?,
value[3].try_into().map_err(RpxDigestError::InvalidFieldElement)?,
]))
}
    }
}

impl From<&[Felt; DIGEST_SIZE]> for RpxDigest {
@@ -199,6 +404,14 @@ impl From<[Felt; DIGEST_SIZE]> for RpxDigest {
    }
}
impl TryFrom<&[u8; DIGEST_BYTES]> for RpxDigest {
type Error = HexParseError;
fn try_from(value: &[u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<[u8; DIGEST_BYTES]> for RpxDigest {
    type Error = HexParseError;
@@ -218,14 +431,6 @@ impl TryFrom<[u8; DIGEST_BYTES]> for RpxDigest {
    }
}

-impl TryFrom<&[u8; DIGEST_BYTES]> for RpxDigest {
-    type Error = HexParseError;
-
-    fn try_from(value: &[u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
-        (*value).try_into()
-    }
-}
-
impl TryFrom<&[u8]> for RpxDigest {
    type Error = HexParseError;
@@ -234,42 +439,12 @@ impl TryFrom<&[u8]> for RpxDigest {
    }
}
-impl TryFrom<[u64; DIGEST_SIZE]> for RpxDigest {
-    type Error = RpxDigestError;
-
-    fn try_from(value: [u64; DIGEST_SIZE]) -> Result<Self, RpxDigestError> {
-        Ok(Self([
-            value[0].try_into().map_err(|_| RpxDigestError::InvalidInteger)?,
-            value[1].try_into().map_err(|_| RpxDigestError::InvalidInteger)?,
-            value[2].try_into().map_err(|_| RpxDigestError::InvalidInteger)?,
-            value[3].try_into().map_err(|_| RpxDigestError::InvalidInteger)?,
-        ]))
-    }
-}
-
-impl TryFrom<&[u64; DIGEST_SIZE]> for RpxDigest {
-    type Error = RpxDigestError;
-
-    fn try_from(value: &[u64; DIGEST_SIZE]) -> Result<Self, RpxDigestError> {
-        (*value).try_into()
-    }
-}
-
impl TryFrom<&str> for RpxDigest {
    type Error = HexParseError;

    /// Expects the string to start with `0x`.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
-        hex_to_bytes(value).and_then(|v| v.try_into())
+        hex_to_bytes::<DIGEST_BYTES>(value).and_then(RpxDigest::try_from)
    }
}

-impl TryFrom<String> for RpxDigest {
-    type Error = HexParseError;
-
-    /// Expects the string to start with `0x`.
-    fn try_from(value: String) -> Result<Self, Self::Error> {
-        value.as_str().try_into()
-    }
-}
@@ -282,6 +457,15 @@ impl TryFrom<&String> for RpxDigest {
    }
}
impl TryFrom<String> for RpxDigest {
type Error = HexParseError;
/// Expects the string to start with `0x`.
fn try_from(value: String) -> Result<Self, Self::Error> {
value.as_str().try_into()
}
}
// SERIALIZATION / DESERIALIZATION
// ================================================================================================
@@ -289,6 +473,10 @@ impl Serializable for RpxDigest {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        target.write_bytes(&self.as_bytes());
    }
+
+    fn get_size_hint(&self) -> usize {
+        Self::SERIALIZED_SIZE
+    }
}

impl Deserializable for RpxDigest {
@@ -308,12 +496,24 @@ impl Deserializable for RpxDigest {
    }
}
// ITERATORS
// ================================================================================================
impl IntoIterator for RpxDigest {
type Item = Felt;
type IntoIter = <[Felt; 4] as IntoIterator>::IntoIter;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
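A small sketch of the new iterator in use (crate-root re-export of `StarkField`, which provides `as_int`, is assumed):

use miden_crypto::{hash::rpx::RpxDigest, StarkField};

fn main() {
    // the default digest is all zeros, so the element sum is zero
    let digest = RpxDigest::default();
    let sum: u64 = digest.into_iter().map(|element| element.as_int()).sum();
    assert_eq!(sum, 0);
}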
// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use alloc::string::String;

    use rand_utils::rand_value;

    use super::{Deserializable, Felt, RpxDigest, Serializable, DIGEST_BYTES, DIGEST_SIZE};
@@ -331,6 +531,7 @@ mod tests {
        let mut bytes = vec![];
        d1.write_into(&mut bytes);
        assert_eq!(DIGEST_BYTES, bytes.len());
+        assert_eq!(bytes.len(), d1.get_size_hint());

        let mut reader = SliceReader::new(&bytes);
        let d2 = RpxDigest::read_from(&mut reader).unwrap();
@@ -338,7 +539,6 @@
        assert_eq!(d1, d2);
    }

-    #[cfg(feature = "std")]
    #[test]
    fn digest_encoding() {
        let digest = RpxDigest([
@@ -363,44 +563,72 @@ mod tests {
            Felt::new(rand_value()),
        ]);
-        let v: [Felt; DIGEST_SIZE] = digest.into();
-        let v2: RpxDigest = v.into();
-        assert_eq!(digest, v2);
-
-        let v: [Felt; DIGEST_SIZE] = (&digest).into();
-        let v2: RpxDigest = v.into();
-        assert_eq!(digest, v2);
-
-        let v: [u64; DIGEST_SIZE] = digest.into();
-        let v2: RpxDigest = v.try_into().unwrap();
-        assert_eq!(digest, v2);
-
-        let v: [u64; DIGEST_SIZE] = (&digest).into();
-        let v2: RpxDigest = v.try_into().unwrap();
-        assert_eq!(digest, v2);
-
-        let v: [u8; DIGEST_BYTES] = digest.into();
-        let v2: RpxDigest = v.try_into().unwrap();
-        assert_eq!(digest, v2);
-
-        let v: [u8; DIGEST_BYTES] = (&digest).into();
-        let v2: RpxDigest = v.try_into().unwrap();
-        assert_eq!(digest, v2);
-
-        let v: String = digest.into();
-        let v2: RpxDigest = v.try_into().unwrap();
-        assert_eq!(digest, v2);
-
-        let v: String = (&digest).into();
-        let v2: RpxDigest = v.try_into().unwrap();
-        assert_eq!(digest, v2);
+        // BY VALUE
+        // ----------------------------------------------------------------------------------------
+        let v: [bool; DIGEST_SIZE] = [true, false, true, true];
+        let v2: RpxDigest = v.into();
+        assert_eq!(v, <[bool; DIGEST_SIZE]>::try_from(v2).unwrap());
+
+        let v: [u8; DIGEST_SIZE] = [0_u8, 1_u8, 2_u8, 3_u8];
+        let v2: RpxDigest = v.into();
+        assert_eq!(v, <[u8; DIGEST_SIZE]>::try_from(v2).unwrap());
+
+        let v: [u16; DIGEST_SIZE] = [0_u16, 1_u16, 2_u16, 3_u16];
+        let v2: RpxDigest = v.into();
+        assert_eq!(v, <[u16; DIGEST_SIZE]>::try_from(v2).unwrap());
+
+        let v: [u32; DIGEST_SIZE] = [0_u32, 1_u32, 2_u32, 3_u32];
+        let v2: RpxDigest = v.into();
+        assert_eq!(v, <[u32; DIGEST_SIZE]>::try_from(v2).unwrap());
+
+        let v: [u64; DIGEST_SIZE] = digest.into();
+        let v2: RpxDigest = v.try_into().unwrap();
+        assert_eq!(digest, v2);
+
+        let v: [Felt; DIGEST_SIZE] = digest.into();
+        let v2: RpxDigest = v.into();
+        assert_eq!(digest, v2);
+
+        let v: [u8; DIGEST_BYTES] = digest.into();
+        let v2: RpxDigest = v.try_into().unwrap();
+        assert_eq!(digest, v2);
+
+        let v: String = digest.into();
+        let v2: RpxDigest = v.try_into().unwrap();
+        assert_eq!(digest, v2);
+
+        // BY REF
+        // ----------------------------------------------------------------------------------------
+        let v: [bool; DIGEST_SIZE] = [true, false, true, true];
+        let v2: RpxDigest = (&v).into();
+        assert_eq!(v, <[bool; DIGEST_SIZE]>::try_from(&v2).unwrap());
+
+        let v: [u8; DIGEST_SIZE] = [0_u8, 1_u8, 2_u8, 3_u8];
+        let v2: RpxDigest = (&v).into();
+        assert_eq!(v, <[u8; DIGEST_SIZE]>::try_from(&v2).unwrap());
+
+        let v: [u16; DIGEST_SIZE] = [0_u16, 1_u16, 2_u16, 3_u16];
+        let v2: RpxDigest = (&v).into();
+        assert_eq!(v, <[u16; DIGEST_SIZE]>::try_from(&v2).unwrap());
+
+        let v: [u32; DIGEST_SIZE] = [0_u32, 1_u32, 2_u32, 3_u32];
+        let v2: RpxDigest = (&v).into();
+        assert_eq!(v, <[u32; DIGEST_SIZE]>::try_from(&v2).unwrap());
+
+        let v: [u64; DIGEST_SIZE] = (&digest).into();
+        let v2: RpxDigest = (&v).try_into().unwrap();
+        assert_eq!(digest, v2);
+
+        let v: [Felt; DIGEST_SIZE] = (&digest).into();
+        let v2: RpxDigest = (&v).into();
+        assert_eq!(digest, v2);
+
+        let v: [u8; DIGEST_BYTES] = (&digest).into();
+        let v2: RpxDigest = (&v).try_into().unwrap();
+        assert_eq!(digest, v2);
+
+        let v: String = (&digest).into();
+        let v2: RpxDigest = (&v).try_into().unwrap();
+        assert_eq!(digest, v2);
    }
}


@@ -9,7 +9,10 @@ use super::{
};

mod digest;
-pub use digest::RpxDigest;
+pub use digest::{RpxDigest, RpxDigestError};
+
+#[cfg(test)]
+mod tests;

pub type CubicExtElement = CubeExtension<Felt>;
@@ -26,8 +29,10 @@ pub type CubicExtElement = CubeExtension<Felt>;
/// * Capacity size: 4 field elements.
/// * S-Box degree: 7.
/// * Rounds: There are 3 different types of rounds:
-///   - (FB): `apply_mds` → `add_constants` → `apply_sbox` → `apply_mds` → `add_constants` → `apply_inv_sbox`.
-///   - (E): `add_constants` → `ext_sbox` (which is raising to power 7 in the degree 3 extension field).
+///   - (FB): `apply_mds` → `add_constants` → `apply_sbox` → `apply_mds` → `add_constants` →
+///     `apply_inv_sbox`.
+///   - (E): `add_constants` → `ext_sbox` (which is raising to power 7 in the degree 3 extension
+///     field).
///   - (M): `apply_mds` → `add_constants`.
/// * Permutation: (FB) (E) (FB) (E) (FB) (E) (M).
///
@@ -53,8 +58,23 @@ pub type CubicExtElement = CubeExtension<Felt>;
///
/// Thus, if the underlying data consists of valid field elements, it might make more sense
/// to deserialize them into field elements and then hash them using
-/// [hash_elements()](Rpx256::hash_elements) function rather then hashing the serialized bytes
+/// [hash_elements()](Rpx256::hash_elements) function rather than hashing the serialized bytes
/// using [hash()](Rpx256::hash) function.
///
/// ## Domain separation
/// [merge_in_domain()](Rpx256::merge_in_domain) hashes two digests into one digest with some domain
/// identifier and the current implementation sets the second capacity element to the value of
/// this domain identifier. Using a similar argument to the one formulated for domain separation
/// in Appendix C of the [specifications](https://eprint.iacr.org/2023/1045), one sees that doing
/// so degrades only pre-image resistance, from its initial bound of c.log_2(p), by as much as
/// the log_2 of the size of the domain identifier space. Since pre-image resistance becomes
/// the bottleneck for the security bound of the sponge in overwrite-mode only when it is
/// lower than 2^128, we see that the target 128-bit security level is maintained as long as
/// the size of the domain identifier space, including for padding, is less than 2^128.
///
/// ## Hashing of empty input
/// The current implementation hashes empty input to the zero digest [0, 0, 0, 0]. This has
/// the benefit of requiring no calls to the RPX permutation when hashing empty input.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Rpx256();
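Both documented behaviors are directly checkable through the public API; a small sketch mirroring the tests later in this diff (crate-root re-exports of `ONE` and `ZERO` assumed):

use miden_crypto::{hash::rpx::{Rpx256, RpxDigest}, ONE, ZERO};

fn main() {
    let digests = [RpxDigest::default(), RpxDigest::default()];

    // domain ZERO leaves the result unchanged; any other domain separates it
    assert_eq!(Rpx256::merge(&digests), Rpx256::merge_in_domain(&digests, ZERO));
    assert_ne!(Rpx256::merge(&digests), Rpx256::merge_in_domain(&digests, ONE));

    // empty input maps to the zero digest with no permutation call
    assert_eq!(Rpx256::hash(b""), RpxDigest::default());
}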
@@ -86,11 +106,18 @@ impl Hasher for Rpx256 {
        // into the state.
        //
        // every time the rate range is filled, a permutation is performed. if the final value of
-        // `i` is not zero, then the chunks count wasn't enough to fill the state range, and an
-        // additional permutation must be performed.
-        let i = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |i, chunk| {
+        // `rate_pos` is not zero, then the chunks count wasn't enough to fill the state range,
+        // and an additional permutation must be performed.
+        let mut current_chunk_idx = 0_usize;
+        // handle the case of an empty `bytes`
+        let last_chunk_idx = if num_field_elem == 0 {
+            current_chunk_idx
+        } else {
+            num_field_elem - 1
+        };
+        let rate_pos = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |rate_pos, chunk| {
            // copy the chunk into the buffer
-            if i != num_field_elem - 1 {
+            if current_chunk_idx != last_chunk_idx {
                buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk);
            } else {
                // on the last iteration, we pad `buf` with a 1 followed by as many 0's as are
@@ -99,18 +126,19 @@ impl Hasher for Rpx256 {
                buf[..chunk.len()].copy_from_slice(chunk);
                buf[chunk.len()] = 1;
            }
+            current_chunk_idx += 1;

            // set the current rate element to the input. since we take at most 7 bytes, we are
            // guaranteed that the inputs data will fit into a single field element.
-            state[RATE_RANGE.start + i] = Felt::new(u64::from_le_bytes(buf));
+            state[RATE_RANGE.start + rate_pos] = Felt::new(u64::from_le_bytes(buf));

            // proceed filling the range. if it's full, then we apply a permutation and reset the
            // counter to the beginning of the range.
-            if i == RATE_WIDTH - 1 {
+            if rate_pos == RATE_WIDTH - 1 {
                Self::apply_permutation(&mut state);
                0
            } else {
-                i + 1
+                rate_pos + 1
            }
        });
@@ -119,8 +147,8 @@ impl Hasher for Rpx256 {
        // don't need to apply any extra padding because the first capacity element contains a
        // flag indicating the number of field elements constituting the last block when the latter
        // is not divisible by `RATE_WIDTH`.
-        if i != 0 {
-            state[RATE_RANGE.start + i..RATE_RANGE.end].fill(ZERO);
+        if rate_pos != 0 {
+            state[RATE_RANGE.start + rate_pos..RATE_RANGE.end].fill(ZERO);
            Self::apply_permutation(&mut state);
        }
@@ -132,7 +160,7 @@ impl Hasher for Rpx256 {
        // initialize the state by copying the digest elements into the rate portion of the state
        // (8 total elements), and set the capacity elements to 0.
        let mut state = [ZERO; STATE_WIDTH];
-        let it = Self::Digest::digests_as_elements(values.iter());
+        let it = Self::Digest::digests_as_elements_iter(values.iter());
        for (i, v) in it.enumerate() {
            state[RATE_RANGE.start + i] = *v;
        }
@@ -142,13 +170,17 @@ impl Hasher for Rpx256 {
        RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }

+    fn merge_many(values: &[Self::Digest]) -> Self::Digest {
+        Self::hash_elements(Self::Digest::digests_as_elements(values))
+    }
+
    fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
        // initialize the state as follows:
        // - seed is copied into the first 4 elements of the rate portion of the state.
        // - if the value fits into a single field element, copy it into the fifth rate element and
        //   set the first capacity element to 5.
-        // - if the value doesn't fit into a single field element, split it into two field
-        //   elements, copy them into rate elements 5 and 6 and set the first capacity element to 6.
+        // - if the value doesn't fit into a single field element, split it into two field elements,
+        //   copy them into rate elements 5 and 6 and set the first capacity element to 6.
        let mut state = [ZERO; STATE_WIDTH];
        state[INPUT1_RANGE].copy_from_slice(seed.as_elements());
        state[INPUT2_RANGE.start] = Felt::new(value);
@@ -159,7 +191,7 @@ impl Hasher for Rpx256 {
            state[CAPACITY_RANGE.start] = Felt::from(6_u8);
        }

-        // apply the RPX permutation and return the first four elements of the state
+        // apply the RPX permutation and return the first four elements of the rate
        Self::apply_permutation(&mut state);
        RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
    }
@@ -265,7 +297,7 @@ impl Rpx256 {
        // initialize the state by copying the digest elements into the rate portion of the state
        // (8 total elements), and set the capacity elements to 0.
        let mut state = [ZERO; STATE_WIDTH];
-        let it = RpxDigest::digests_as_elements(values.iter());
+        let it = RpxDigest::digests_as_elements_iter(values.iter());
        for (i, v) in it.enumerate() {
            state[RATE_RANGE.start + i] = *v;
        }


@@ -0,0 +1,186 @@
use alloc::{collections::BTreeSet, vec::Vec};
use proptest::prelude::*;
use rand_utils::rand_value;
use super::{Felt, Hasher, Rpx256, StarkField, ZERO};
use crate::{hash::rescue::RpxDigest, ONE};
#[test]
fn hash_elements_vs_merge() {
let elements = [Felt::new(rand_value()); 8];
let digests: [RpxDigest; 2] = [
RpxDigest::new(elements[..4].try_into().unwrap()),
RpxDigest::new(elements[4..].try_into().unwrap()),
];
let m_result = Rpx256::merge(&digests);
let h_result = Rpx256::hash_elements(&elements);
assert_eq!(m_result, h_result);
}
#[test]
fn merge_vs_merge_in_domain() {
let elements = [Felt::new(rand_value()); 8];
let digests: [RpxDigest; 2] = [
RpxDigest::new(elements[..4].try_into().unwrap()),
RpxDigest::new(elements[4..].try_into().unwrap()),
];
let merge_result = Rpx256::merge(&digests);
// ----- merge with domain = 0 ----------------------------------------------------------------
// set domain to ZERO. This should not change the result.
let domain = ZERO;
let merge_in_domain_result = Rpx256::merge_in_domain(&digests, domain);
assert_eq!(merge_result, merge_in_domain_result);
// ----- merge with domain = 1 ----------------------------------------------------------------
// set domain to ONE. This should change the result.
let domain = ONE;
let merge_in_domain_result = Rpx256::merge_in_domain(&digests, domain);
assert_ne!(merge_result, merge_in_domain_result);
}
#[test]
fn hash_elements_vs_merge_with_int() {
let tmp = [Felt::new(rand_value()); 4];
let seed = RpxDigest::new(tmp);
// ----- value fits into a field element ------------------------------------------------------
let val: Felt = Felt::new(rand_value());
let m_result = Rpx256::merge_with_int(seed, val.as_int());
let mut elements = seed.as_elements().to_vec();
elements.push(val);
let h_result = Rpx256::hash_elements(&elements);
assert_eq!(m_result, h_result);
// ----- value does not fit into a field element ----------------------------------------------
let val = Felt::MODULUS + 2;
let m_result = Rpx256::merge_with_int(seed, val);
let mut elements = seed.as_elements().to_vec();
elements.push(Felt::new(val));
elements.push(ONE);
let h_result = Rpx256::hash_elements(&elements);
assert_eq!(m_result, h_result);
}
#[test]
fn hash_padding() {
    // appending a zero byte at the end of a byte string should result in a different hash
let r1 = Rpx256::hash(&[1_u8, 2, 3]);
let r2 = Rpx256::hash(&[1_u8, 2, 3, 0]);
assert_ne!(r1, r2);
// same as above but with bigger inputs
let r1 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6]);
let r2 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6, 0]);
assert_ne!(r1, r2);
// same as above but with input splitting over two elements
let r1 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6, 7]);
let r2 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6, 7, 0]);
assert_ne!(r1, r2);
// same as above but with multiple zeros
let r1 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6, 7, 0, 0]);
let r2 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0]);
assert_ne!(r1, r2);
}
#[test]
fn hash_elements_padding() {
let e1 = [Felt::new(rand_value()); 2];
let e2 = [e1[0], e1[1], ZERO];
let r1 = Rpx256::hash_elements(&e1);
let r2 = Rpx256::hash_elements(&e2);
assert_ne!(r1, r2);
}
#[test]
fn hash_elements() {
let elements = [
ZERO,
ONE,
Felt::new(2),
Felt::new(3),
Felt::new(4),
Felt::new(5),
Felt::new(6),
Felt::new(7),
];
let digests: [RpxDigest; 2] = [
RpxDigest::new(elements[..4].try_into().unwrap()),
RpxDigest::new(elements[4..8].try_into().unwrap()),
];
let m_result = Rpx256::merge(&digests);
let h_result = Rpx256::hash_elements(&elements);
assert_eq!(m_result, h_result);
}
#[test]
fn hash_empty() {
let elements: Vec<Felt> = vec![];
let zero_digest = RpxDigest::default();
let h_result = Rpx256::hash_elements(&elements);
assert_eq!(zero_digest, h_result);
}
#[test]
fn hash_empty_bytes() {
let bytes: Vec<u8> = vec![];
let zero_digest = RpxDigest::default();
let h_result = Rpx256::hash(&bytes);
assert_eq!(zero_digest, h_result);
}
#[test]
fn sponge_bytes_with_remainder_length_wont_panic() {
    // this test aims to assert that no panic will happen with the edge case of an input whose
    // length is not divisible by the binary chunk size in use. 113 is a non-negligible input
    // length that is prime, hence guaranteed not to be divisible by any choice of chunk size.
    //
    // this is a preliminary test for the proptest fuzzing below.
Rpx256::hash(&[0; 113]);
}
#[test]
fn sponge_collision_for_wrapped_field_element() {
let a = Rpx256::hash(&[0; 8]);
let b = Rpx256::hash(&Felt::MODULUS.to_le_bytes());
assert_ne!(a, b);
}
#[test]
fn sponge_zeroes_collision() {
let mut zeroes = Vec::with_capacity(255);
let mut set = BTreeSet::new();
(0..255).for_each(|_| {
let hash = Rpx256::hash(&zeroes);
zeroes.push(0);
// panic if a collision was found
assert!(set.insert(hash));
});
}
proptest! {
#[test]
    fn rpx256_wont_panic_with_arbitrary_input(ref bytes in any::<Vec<u8>>()) {
Rpx256::hash(bytes);
}
}

View File

@@ -4,8 +4,9 @@ use clap::Parser;
use miden_crypto::{
    hash::rpo::{Rpo256, RpoDigest},
    merkle::{MerkleError, Smt},
-    Felt, Word, ONE,
+    Felt, Word, EMPTY_WORD, ONE,
};
+use rand::{prelude::IteratorRandom, thread_rng, Rng};
use rand_utils::rand_value;
#[derive(Parser, Debug)]
@@ -13,7 +14,7 @@ use rand_utils::rand_value;
pub struct BenchmarkCmd {
    /// Size of the tree
    #[clap(short = 's', long = "size")]
-    size: u64,
+    size: usize,
}
fn main() {
@@ -29,82 +30,184 @@ pub fn benchmark_smt() {
    let mut entries = Vec::new();
    for i in 0..tree_size {
        let key = rand_value::<RpoDigest>();
-        let value = [ONE, ONE, ONE, Felt::new(i)];
+        let value = [ONE, ONE, ONE, Felt::new(i as u64)];
        entries.push((key, value));
    }

-    let mut tree = construction(entries, tree_size).unwrap();
-    insertion(&mut tree, tree_size).unwrap();
-    proof_generation(&mut tree, tree_size).unwrap();
+    let mut tree = construction(entries.clone(), tree_size).unwrap();
+    insertion(&mut tree).unwrap();
+    batched_insertion(&mut tree).unwrap();
+    batched_update(&mut tree, entries).unwrap();
+    proof_generation(&mut tree).unwrap();
}

/// Runs the construction benchmark for [`Smt`], returning the constructed tree.
-pub fn construction(entries: Vec<(RpoDigest, Word)>, size: u64) -> Result<Smt, MerkleError> {
+pub fn construction(entries: Vec<(RpoDigest, Word)>, size: usize) -> Result<Smt, MerkleError> {
    println!("Running a construction benchmark:");
    let now = Instant::now();
    let tree = Smt::with_entries(entries)?;
-    let elapsed = now.elapsed();
-    println!(
-        "Constructed a SMT with {} key-value pairs in {:.3} seconds",
-        size,
-        elapsed.as_secs_f32(),
-    );
+    let elapsed = now.elapsed().as_secs_f32();
+    println!("Constructed a SMT with {size} key-value pairs in {elapsed:.1} seconds");
    println!("Number of leaf nodes: {}\n", tree.leaves().count());

    Ok(tree)
}

/// Runs the insertion benchmark for the [`Smt`].
-pub fn insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
+pub fn insertion(tree: &mut Smt) -> Result<(), MerkleError> {
+    const NUM_INSERTIONS: usize = 1_000;
+
    println!("Running an insertion benchmark:");

+    let size = tree.num_leaves();
    let mut insertion_times = Vec::new();
-    for i in 0..20 {
+    for i in 0..NUM_INSERTIONS {
        let test_key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
-        let test_value = [ONE, ONE, ONE, Felt::new(size + i)];
+        let test_value = [ONE, ONE, ONE, Felt::new((size + i) as u64)];

        let now = Instant::now();
        tree.insert(test_key, test_value);
        let elapsed = now.elapsed();
-        insertion_times.push(elapsed.as_secs_f32());
+        insertion_times.push(elapsed.as_micros());
    }

    println!(
-        "An average insertion time measured by 20 inserts into a SMT with {} key-value pairs is {:.3} milliseconds\n",
-        size,
-        // calculate the average by dividing by 20 and convert to milliseconds by multiplying by
-        // 1000. As a result, we can only multiply by 50
-        insertion_times.iter().sum::<f32>() * 50f32,
+        "An average insertion time measured by {NUM_INSERTIONS} inserts into an SMT with {size} leaves is {:.0} μs\n",
+        // calculate the average
+        insertion_times.iter().sum::<u128>() as f64 / (NUM_INSERTIONS as f64),
    );

    Ok(())
}
pub fn batched_insertion(tree: &mut Smt) -> Result<(), MerkleError> {
const NUM_INSERTIONS: usize = 1_000;
println!("Running a batched insertion benchmark:");
let size = tree.num_leaves();
let new_pairs: Vec<(RpoDigest, Word)> = (0..NUM_INSERTIONS)
.map(|i| {
let key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
let value = [ONE, ONE, ONE, Felt::new((size + i) as u64)];
(key, value)
})
.collect();
let now = Instant::now();
let mutations = tree.compute_mutations(new_pairs);
let compute_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms
let now = Instant::now();
tree.apply_mutations(mutations)?;
let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms
println!(
"An average insert-batch computation time measured by a {NUM_INSERTIONS}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
compute_elapsed,
compute_elapsed * 1000_f64 / NUM_INSERTIONS as f64, // time in μs
);
println!(
"An average insert-batch application time measured by a {NUM_INSERTIONS}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
apply_elapsed,
apply_elapsed * 1000_f64 / NUM_INSERTIONS as f64, // time in μs
);
println!(
"An average batch insertion time measured by a 1k-batch into an SMT with {size} leaves totals to {:.1} ms",
(compute_elapsed + apply_elapsed),
);
println!();
Ok(())
}
pub fn batched_update(tree: &mut Smt, entries: Vec<(RpoDigest, Word)>) -> Result<(), MerkleError> {
const NUM_UPDATES: usize = 1_000;
const REMOVAL_PROBABILITY: f64 = 0.2;
println!("Running a batched update benchmark:");
let size = tree.num_leaves();
let mut rng = thread_rng();
let new_pairs =
entries
.into_iter()
.choose_multiple(&mut rng, NUM_UPDATES)
.into_iter()
.map(|(key, _)| {
let value = if rng.gen_bool(REMOVAL_PROBABILITY) {
EMPTY_WORD
} else {
[ONE, ONE, ONE, Felt::new(rng.gen())]
};
(key, value)
});
assert_eq!(new_pairs.len(), NUM_UPDATES);
let now = Instant::now();
let mutations = tree.compute_mutations(new_pairs);
let compute_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms
let now = Instant::now();
tree.apply_mutations(mutations)?;
let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms
println!(
"An average update-batch computation time measured by a {NUM_UPDATES}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
compute_elapsed,
compute_elapsed * 1000_f64 / NUM_UPDATES as f64, // time in μs
);
println!(
"An average update-batch application time measured by a {NUM_UPDATES}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
apply_elapsed,
apply_elapsed * 1000_f64 / NUM_UPDATES as f64, // time in μs
);
println!(
"An average batch update time measured by a 1k-batch into an SMT with {size} leaves totals to {:.1} ms",
(compute_elapsed + apply_elapsed),
);
println!();
Ok(())
}
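The batched benchmarks above rely on the two-phase mutation API: `compute_mutations` is a read-only pass that derives every node change, and `apply_mutations` commits them in one traversal. A minimal wrapper sketch (same API as used above; the helper name is hypothetical):

use miden_crypto::{hash::rpo::RpoDigest, merkle::Smt, Word};

fn apply_batch(tree: &mut Smt, pairs: Vec<(RpoDigest, Word)>) {
    // phase 1: compute all prospective node updates without touching the tree
    let mutations = tree.compute_mutations(pairs);
    // phase 2: commit the precomputed updates in a single pass
    tree.apply_mutations(mutations).expect("mutations were computed against this tree");
}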
/// Runs the proof generation benchmark for the [`Smt`].
-pub fn proof_generation(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
+pub fn proof_generation(tree: &mut Smt) -> Result<(), MerkleError> {
+    const NUM_PROOFS: usize = 100;
+
    println!("Running a proof generation benchmark:");

    let mut insertion_times = Vec::new();
-    for i in 0..20 {
+    let size = tree.num_leaves();
+    for i in 0..NUM_PROOFS {
        let test_key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
-        let test_value = [ONE, ONE, ONE, Felt::new(size + i)];
+        let test_value = [ONE, ONE, ONE, Felt::new((size + i) as u64)];
        tree.insert(test_key, test_value);

        let now = Instant::now();
        let _proof = tree.open(&test_key);
-        let elapsed = now.elapsed();
-        insertion_times.push(elapsed.as_secs_f32());
+        insertion_times.push(now.elapsed().as_micros());
    }

    println!(
-        "An average proving time measured by 20 value proofs in a SMT with {} key-value pairs in {:.3} microseconds",
-        size,
-        // calculate the average by dividing by 20 and convert to microseconds by multiplying by
-        // 1000000. As a result, we can only multiply by 50000
-        insertion_times.iter().sum::<f32>() * 50000f32,
+        "An average proving time measured by {NUM_PROOFS} value proofs in an SMT with {size} leaves in {:.0} μs",
+        // calculate the average
+        insertion_times.iter().sum::<u128>() as f64 / (NUM_PROOFS as f64),
    );

    Ok(())


@@ -1,6 +1,6 @@
use core::slice;

-use super::{Felt, RpoDigest, EMPTY_WORD};
+use super::{smt::InnerNode, Felt, RpoDigest, EMPTY_WORD};

// EMPTY NODES SUBTREES
// ================================================================================================
@@ -25,6 +25,17 @@ impl EmptySubtreeRoots {
        let pos = 255 - tree_depth + node_depth;
        &EMPTY_SUBTREES[pos as usize]
    }
/// Returns a sparse Merkle tree [`InnerNode`] with two empty children.
///
/// # Note
/// `node_depth` is the depth of the **parent** to have empty children. That is, `node_depth`
/// and the depth of the returned [`InnerNode`] are the same, and thus the empty hashes are for
/// subtrees of depth `node_depth + 1`.
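    ///
    /// For example, `get_inner_node(64, 63)` returns an [`InnerNode`] whose `left` and `right`
    /// children are both `EmptySubtreeRoots::entry(64, 64)`, i.e. the empty-subtree hashes one
    /// level below the parent.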
pub(crate) const fn get_inner_node(tree_depth: u8, node_depth: u8) -> InnerNode {
let &child = Self::entry(tree_depth, node_depth + 1);
InnerNode { left: child, right: child }
}
}

const EMPTY_SUBTREES: [RpoDigest; 256] = [


@@ -1,65 +1,34 @@
-use alloc::vec::Vec;
-use core::fmt;
+use thiserror::Error;

-use super::{smt::SmtLeafError, MerklePath, NodeIndex, RpoDigest};
+use super::{NodeIndex, RpoDigest};

-#[derive(Clone, Debug, PartialEq, Eq)]
+#[derive(Debug, Error)]
pub enum MerkleError {
-    ConflictingRoots(Vec<RpoDigest>),
+    #[error("expected merkle root {expected_root} found {actual_root}")]
+    ConflictingRoots {
+        expected_root: RpoDigest,
+        actual_root: RpoDigest,
+    },
+    #[error("provided merkle tree depth {0} is too small")]
    DepthTooSmall(u8),
+    #[error("provided merkle tree depth {0} is too big")]
    DepthTooBig(u64),
+    #[error("multiple values provided for merkle tree index {0}")]
    DuplicateValuesForIndex(u64),
-    DuplicateValuesForKey(RpoDigest),
-    InvalidIndex { depth: u8, value: u64 },
+    #[error("node index value {value} is not valid for depth {depth}")]
+    InvalidNodeIndex { depth: u8, value: u64 },
-    InvalidDepth { expected: u8, provided: u8 },
+    #[error("provided node index depth {provided} does not match expected depth {expected}")]
+    InvalidNodeIndexDepth { expected: u8, provided: u8 },
-    InvalidSubtreeDepth { subtree_depth: u8, tree_depth: u8 },
+    #[error("merkle subtree depth {subtree_depth} exceeds merkle tree depth {tree_depth}")]
+    SubtreeDepthExceedsDepth { subtree_depth: u8, tree_depth: u8 },
-    InvalidPath(MerklePath),
-    InvalidNumEntries(usize),
+    #[error("number of entries in the merkle tree exceeds the maximum of {0}")]
+    TooManyEntries(usize),
-    NodeNotInSet(NodeIndex),
+    #[error("node index `{0}` not found in the tree")]
+    NodeIndexNotFoundInTree(NodeIndex),
-    NodeNotInStore(RpoDigest, NodeIndex),
+    #[error("node {0:?} with index `{1}` not found in the store")]
+    NodeIndexNotFoundInStore(RpoDigest, NodeIndex),
+    #[error("number of provided merkle tree leaves {0} is not a power of two")]
    NumLeavesNotPowerOfTwo(usize),
+    #[error("root {0:?} is not in the store")]
    RootNotInStore(RpoDigest),
-    SmtLeaf(SmtLeafError),
}

-impl fmt::Display for MerkleError {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        use MerkleError::*;
-        match self {
-            ConflictingRoots(roots) => write!(f, "the merkle paths roots do not match {roots:?}"),
-            DepthTooSmall(depth) => write!(f, "the provided depth {depth} is too small"),
-            DepthTooBig(depth) => write!(f, "the provided depth {depth} is too big"),
-            DuplicateValuesForIndex(key) => write!(f, "multiple values provided for key {key}"),
-            DuplicateValuesForKey(key) => write!(f, "multiple values provided for key {key}"),
-            InvalidIndex { depth, value } => {
-                write!(f, "the index value {value} is not valid for the depth {depth}")
-            }
-            InvalidDepth { expected, provided } => {
-                write!(f, "the provided depth {provided} is not valid for {expected}")
-            }
-            InvalidSubtreeDepth { subtree_depth, tree_depth } => {
-                write!(f, "tried inserting a subtree of depth {subtree_depth} into a tree of depth {tree_depth}")
-            }
-            InvalidPath(_path) => write!(f, "the provided path is not valid"),
-            InvalidNumEntries(max) => write!(f, "number of entries exceeded the maximum: {max}"),
-            NodeNotInSet(index) => write!(f, "the node with index ({index}) is not in the set"),
-            NodeNotInStore(hash, index) => {
-                write!(f, "the node {hash:?} with index ({index}) is not in the store")
-            }
-            NumLeavesNotPowerOfTwo(leaves) => {
-                write!(f, "the leaves count {leaves} is not a power of 2")
-            }
-            RootNotInStore(root) => write!(f, "the root {:?} is not in the store", root),
-            SmtLeaf(smt_leaf_error) => write!(f, "smt leaf error: {smt_leaf_error}"),
-        }
-    }
-}
-
-#[cfg(feature = "std")]
-impl std::error::Error for MerkleError {}
-
-impl From<SmtLeafError> for MerkleError {
-    fn from(value: SmtLeafError) -> Self {
-        Self::SmtLeaf(value)
-    }
-}


@@ -38,7 +38,7 @@ impl NodeIndex {
    /// Returns an error if the `value` is greater than or equal to 2^{depth}.
    pub const fn new(depth: u8, value: u64) -> Result<Self, MerkleError> {
        if (64 - value.leading_zeros()) > depth as u32 {
-            Err(MerkleError::InvalidIndex { depth, value })
+            Err(MerkleError::InvalidNodeIndex { depth, value })
        } else {
            Ok(Self { depth, value })
        }
@@ -128,7 +128,7 @@ impl NodeIndex {
        self.value
    }

-    /// Returns true if the current instance points to a right sibling node.
+    /// Returns `true` if the current instance points to a right sibling node.
    pub const fn is_value_odd(&self) -> bool {
        (self.value & 1) == 1
    }
@@ -182,6 +182,7 @@ impl Deserializable for NodeIndex {
#[cfg(test)]
mod tests {
+    use assert_matches::assert_matches;
    use proptest::prelude::*;

    use super::*;
@@ -190,19 +191,19 @@ mod tests {
    fn test_node_index_value_too_high() {
        assert_eq!(NodeIndex::new(0, 0).unwrap(), NodeIndex { depth: 0, value: 0 });
        let err = NodeIndex::new(0, 1).unwrap_err();
-        assert_eq!(err, MerkleError::InvalidIndex { depth: 0, value: 1 });
+        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 0, value: 1 });

        assert_eq!(NodeIndex::new(1, 1).unwrap(), NodeIndex { depth: 1, value: 1 });
        let err = NodeIndex::new(1, 2).unwrap_err();
-        assert_eq!(err, MerkleError::InvalidIndex { depth: 1, value: 2 });
+        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 1, value: 2 });

        assert_eq!(NodeIndex::new(2, 3).unwrap(), NodeIndex { depth: 2, value: 3 });
        let err = NodeIndex::new(2, 4).unwrap_err();
-        assert_eq!(err, MerkleError::InvalidIndex { depth: 2, value: 4 });
+        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 2, value: 4 });

        assert_eq!(NodeIndex::new(3, 7).unwrap(), NodeIndex { depth: 3, value: 7 });
        let err = NodeIndex::new(3, 8).unwrap_err();
-        assert_eq!(err, MerkleError::InvalidIndex { depth: 3, value: 8 });
+        assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 3, value: 8 });
    }

    #[test]


@@ -1,8 +1,6 @@
use alloc::{string::String, vec::Vec};
use core::{fmt, ops::Deref, slice};

-use winter_math::log2;
-
use super::{InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Word};
use crate::utils::{uninit_vector, word_to_hex};
@@ -70,7 +68,7 @@ impl MerkleTree {
    ///
    /// Merkle tree of depth 1 has two leaves, depth 2 has four leaves etc.
    pub fn depth(&self) -> u8 {
-        log2(self.nodes.len() / 2) as u8
+        (self.nodes.len() / 2).ilog2() as u8
    }

    /// Returns a node at the specified depth and index value.
@@ -213,7 +211,7 @@ pub struct InnerNodeIterator<'a> {
    index: usize,
}

-impl<'a> Iterator for InnerNodeIterator<'a> {
+impl Iterator for InnerNodeIterator<'_> {
    type Item = InnerNodeInfo;

    fn next(&mut self) -> Option<Self::Item> {


@@ -1,6 +1,7 @@
-use super::super::RpoDigest;
use alloc::vec::Vec;

+use super::super::RpoDigest;
+
/// Container for the update data of a [super::PartialMmr]
#[derive(Debug)]
pub struct MmrDelta {


@@ -1,35 +1,27 @@
-use core::fmt::{Display, Formatter};
-#[cfg(feature = "std")]
-use std::error::Error;
+use alloc::string::String;
+
+use thiserror::Error;

 use crate::merkle::MerkleError;

-#[derive(Debug, PartialEq, Eq, Clone)]
+#[derive(Debug, Error)]
 pub enum MmrError {
-    InvalidPosition(usize),
-    InvalidPeaks,
-    InvalidPeak,
+    #[error("mmr does not contain position {0}")]
+    PositionNotFound(usize),
+    #[error("mmr peaks are invalid: {0}")]
+    InvalidPeaks(String),
+    #[error(
+        "mmr peak does not match the computed merkle root of the provided authentication path"
+    )]
+    PeakPathMismatch,
+    #[error("requested peak index is {peak_idx} but the number of peaks is {peaks_len}")]
+    PeakOutOfBounds { peak_idx: usize, peaks_len: usize },
+    #[error("invalid mmr update")]
     InvalidUpdate,
-    UnknownPeak,
-    MerkleError(MerkleError),
+    #[error("mmr does not contain a peak with depth {0}")]
+    UnknownPeak(u8),
+    #[error("invalid merkle path")]
+    InvalidMerklePath(#[source] MerkleError),
+    #[error("merkle root computation failed")]
+    MerkleRootComputationFailed(#[source] MerkleError),
 }
-
-impl Display for MmrError {
-    fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), core::fmt::Error> {
-        match self {
-            MmrError::InvalidPosition(pos) => write!(fmt, "Mmr does not contain position {pos}"),
-            MmrError::InvalidPeaks => write!(fmt, "Invalid peaks count"),
-            MmrError::InvalidPeak => {
-                write!(fmt, "Peak values does not match merkle path computed root")
-            }
-            MmrError::InvalidUpdate => write!(fmt, "Invalid mmr update"),
-            MmrError::UnknownPeak => {
-                write!(fmt, "Peak not in Mmr")
-            }
-            MmrError::MerkleError(err) => write!(fmt, "{}", err),
-        }
-    }
-}
-
-#[cfg(feature = "std")]
-impl Error for MmrError {}
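For context on the migration above: thiserror's `#[error(...)]` attributes generate the `Display` impl that the deleted hand-written block used to provide, and `#[source]` wires the wrapped error into `Error::source`. A standalone sketch (not the crate's actual type) illustrating the derived behavior:

```rust
use thiserror::Error;

#[derive(Debug, Error)]
enum DemoError {
    #[error("mmr does not contain position {0}")]
    PositionNotFound(usize),
    #[error("invalid merkle path")]
    InvalidMerklePath(#[source] std::io::Error),
}

fn main() {
    // Display is derived from the #[error] attribute.
    assert_eq!(DemoError::PositionNotFound(7).to_string(), "mmr does not contain position 7");

    // #[source] exposes the inner error through Error::source().
    let inner = std::io::Error::new(std::io::ErrorKind::Other, "bad node");
    let err = DemoError::InvalidMerklePath(inner);
    assert!(std::error::Error::source(&err).is_some());
}
```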


@@ -7,16 +7,17 @@
 //!
 //! Additionally the structure only supports adding leaves to the right-most tree, the one with the
 //! least number of leaves. The structure preserves the invariant that each tree has different
-//! depths, i.e. as part of adding adding a new element to the forest the trees with same depth are
+//! depths, i.e. as part of adding a new element to the forest the trees with same depth are
 //! merged, creating a new tree with depth d+1, this process is continued until the property is
 //! reestablished.
+use alloc::vec::Vec;
+
 use super::{
     super::{InnerNodeInfo, MerklePath},
     bit::TrueBitPositionIterator,
     leaf_to_corresponding_tree, nodes_in_forest, MmrDelta, MmrError, MmrPeaks, MmrProof, Rpo256,
     RpoDigest,
 };
-use alloc::vec::Vec;

 // MMR
 // ===============================================================================================
@@ -72,19 +73,36 @@ impl Mmr {
     // FUNCTIONALITY
     // ============================================================================================

-    /// Given a leaf position, returns the Merkle path to its corresponding peak. If the position
-    /// is greater-or-equal than the tree size an error is returned.
+    /// Returns an [MmrProof] for the leaf at the specified position.
     ///
     /// Note: The leaf position is the 0-indexed number corresponding to the order the leaves were
     /// added, this corresponds to the MMR size _prior_ to adding the element. So the 1st element
     /// has position 0, the second position 1, and so on.
-    pub fn open(&self, pos: usize, target_forest: usize) -> Result<MmrProof, MmrError> {
+    ///
+    /// # Errors
+    /// Returns an error if the specified leaf position is out of bounds for this MMR.
+    pub fn open(&self, pos: usize) -> Result<MmrProof, MmrError> {
+        self.open_at(pos, self.forest)
+    }
+
+    /// Returns an [MmrProof] for the leaf at the specified position using the state of the MMR
+    /// at the specified `forest`.
+    ///
+    /// Note: The leaf position is the 0-indexed number corresponding to the order the leaves were
+    /// added, this corresponds to the MMR size _prior_ to adding the element. So the 1st element
+    /// has position 0, the second position 1, and so on.
+    ///
+    /// # Errors
+    /// Returns an error if:
+    /// - The specified leaf position is out of bounds for this MMR.
+    /// - The specified `forest` value is not valid for this MMR.
+    pub fn open_at(&self, pos: usize, forest: usize) -> Result<MmrProof, MmrError> {
         // find the target tree responsible for the MMR position
         let tree_bit =
-            leaf_to_corresponding_tree(pos, target_forest).ok_or(MmrError::InvalidPosition(pos))?;
+            leaf_to_corresponding_tree(pos, forest).ok_or(MmrError::PositionNotFound(pos))?;

         // isolate the trees before the target
-        let forest_before = target_forest & high_bitmask(tree_bit + 1);
+        let forest_before = forest & high_bitmask(tree_bit + 1);
         let index_offset = nodes_in_forest(forest_before);

         // update the value position from global to the target tree
@@ -94,7 +112,7 @@ impl Mmr {
         let (_, path) = self.collect_merkle_path_and_value(tree_bit, relative_pos, index_offset);

         Ok(MmrProof {
-            forest: target_forest,
+            forest,
             position: pos,
             merkle_path: MerklePath::new(path),
         })
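A hedged usage sketch of the reworked API (import paths assumed from miden-crypto's public exports): `open` now defaults to the current forest, while `open_at` serves openings against historical states:

```rust
use miden_crypto::hash::rpo::Rpo256;
use miden_crypto::merkle::{Mmr, MmrError};
use miden_crypto::Felt;

fn main() -> Result<(), MmrError> {
    let mut mmr = Mmr::new();
    for i in 0..4u64 {
        mmr.add(Rpo256::hash_elements(&[Felt::new(i)]));
    }

    // Open against the current state; the forest argument is now implicit.
    let proof = mmr.open(0)?;
    mmr.peaks().verify(mmr.get(0)?, proof)?;

    // Open against an older state: forest = 1 is the MMR right after the
    // first insertion, so the proof's merkle path is empty.
    let old_proof = mmr.open_at(0, 1)?;
    assert_eq!(old_proof.forest, 1);
    Ok(())
}
```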
@@ -108,7 +126,7 @@ impl Mmr {
     pub fn get(&self, pos: usize) -> Result<RpoDigest, MmrError> {
         // find the target tree responsible for the MMR position
         let tree_bit =
-            leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::InvalidPosition(pos))?;
+            leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::PositionNotFound(pos))?;

         // isolate the trees before the target
         let forest_before = self.forest & high_bitmask(tree_bit + 1);
@@ -145,10 +163,21 @@ impl Mmr {
         self.forest += 1;
     }

-    /// Returns an peaks of the MMR for the version specified by `forest`.
-    pub fn peaks(&self, forest: usize) -> Result<MmrPeaks, MmrError> {
+    /// Returns the current peaks of the MMR.
+    pub fn peaks(&self) -> MmrPeaks {
+        self.peaks_at(self.forest).expect("failed to get peaks at current forest")
+    }
+
+    /// Returns the peaks of the MMR at the state specified by `forest`.
+    ///
+    /// # Errors
+    /// Returns an error if the specified `forest` value is not valid for this MMR.
+    pub fn peaks_at(&self, forest: usize) -> Result<MmrPeaks, MmrError> {
         if forest > self.forest {
-            return Err(MmrError::InvalidPeaks);
+            return Err(MmrError::InvalidPeaks(format!(
+                "requested forest {forest} exceeds current forest {}",
+                self.forest
+            )));
         }

         let peaks: Vec<RpoDigest> = TrueBitPositionIterator::new(forest)
@@ -173,7 +202,7 @@ impl Mmr {
     /// that have been merged together, followed by the new peaks of the [Mmr].
     pub fn get_delta(&self, from_forest: usize, to_forest: usize) -> Result<MmrDelta, MmrError> {
         if to_forest > self.forest || from_forest > to_forest {
-            return Err(MmrError::InvalidPeaks);
+            return Err(MmrError::InvalidPeaks(format!("to_forest {to_forest} exceeds the current forest {} or from_forest {from_forest} exceeds to_forest", self.forest)));
         }

         if from_forest == to_forest {
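Beyond the error-message change, `get_delta` pairs with `PartialMmr::apply` (see partial.rs below) to keep a light client in sync with a full MMR. A sketch under the same assumed import paths:

```rust
use miden_crypto::hash::rpo::Rpo256;
use miden_crypto::merkle::{Mmr, MmrError, PartialMmr};
use miden_crypto::Felt;

fn main() -> Result<(), MmrError> {
    let leaf = |i: u64| Rpo256::hash_elements(&[Felt::new(i)]);

    // Full MMR with 3 leaves, and a partial view of its peaks.
    let mut mmr = Mmr::new();
    (0..3u64).for_each(|i| mmr.add(leaf(i)));
    let mut partial: PartialMmr = mmr.peaks().into();

    // The full MMR grows by four more leaves...
    (3..7u64).for_each(|i| mmr.add(leaf(i)));

    // ...and the partial view catches up by applying a delta.
    let delta = mmr.get_delta(partial.forest(), mmr.forest())?;
    partial.apply(delta)?;
    assert_eq!(partial.peaks(), mmr.peaks());
    Ok(())
}
```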
@@ -344,7 +373,7 @@ pub struct MmrNodes<'a> {
     index: usize,
 }

-impl<'a> Iterator for MmrNodes<'a> {
+impl Iterator for MmrNodes<'_> {
     type Item = InnerNodeInfo;

     fn next(&mut self) -> Option<Self::Item> {
@@ -377,7 +406,8 @@ impl<'a> Iterator for MmrNodes<'a> {
             // the next parent position is one above the position of the pair
             let parent = self.last_right << 1;

-            // the left node has been paired and the current parent yielded, removed it from the forest
+            // the left node has been paired and the current parent yielded, removed it from the
+            // forest
             self.forest ^= self.last_right;

             if self.forest & parent == 0 {
                 // this iteration yielded the left parent node


@@ -6,6 +6,8 @@
 //! leaves count.
 use core::num::NonZeroUsize;

+use winter_utils::{Deserializable, Serializable};
+
 // IN-ORDER INDEX
 // ================================================================================================
@@ -112,6 +114,21 @@ impl InOrderIndex {
     }
 }

+impl Serializable for InOrderIndex {
+    fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
+        target.write_usize(self.idx);
+    }
+}
+
+impl Deserializable for InOrderIndex {
+    fn read_from<R: winter_utils::ByteReader>(
+        source: &mut R,
+    ) -> Result<Self, winter_utils::DeserializationError> {
+        let idx = source.read_usize()?;
+        Ok(InOrderIndex { idx })
+    }
+}
+
 // CONVERSIONS FROM IN-ORDER INDEX
 // ------------------------------------------------------------------------------------------------
@@ -127,6 +144,7 @@ impl From<InOrderIndex> for u64 {
 #[cfg(test)]
 mod test {
     use proptest::prelude::*;
+    use winter_utils::{Deserializable, Serializable};

     use super::InOrderIndex;
@@ -162,4 +180,12 @@ mod test {
         assert_eq!(left.sibling(), right);
         assert_eq!(left, right.sibling());
     }

+    #[test]
+    fn test_inorder_index_serialization() {
+        let index = InOrderIndex::from_leaf_pos(5);
+        let bytes = index.to_bytes();
+        let index2 = InOrderIndex::read_from_bytes(&bytes).unwrap();
+        assert_eq!(index, index2);
+    }
 }


@@ -10,8 +10,6 @@ mod proof;
 #[cfg(test)]
 mod tests;

-use super::{Felt, Rpo256, RpoDigest, Word};
-
 // REEXPORTS
 // ================================================================================================
 pub use delta::MmrDelta;
@@ -22,6 +20,8 @@ pub use partial::PartialMmr;
 pub use peaks::MmrPeaks;
 pub use proof::MmrProof;

+use super::{Felt, Rpo256, RpoDigest, Word};
+
 // UTILITIES
 // ===============================================================================================
@@ -42,8 +42,8 @@ const fn leaf_to_corresponding_tree(pos: usize, forest: usize) -> Option<u32> {
     //   - this means the first tree owns from `0` up to the `2^k_0` first positions, where `k_0`
     //     is the highest true bit position, the second tree from `2^k_0 + 1` up to `2^k_1` where
     //     `k_1` is the second highest bit, so on.
-    //   - this means the highest bits work as a category marker, and the position is owned by
-    //     the first tree which doesn't share a high bit with the position
+    //   - this means the highest bits work as a category marker, and the position is owned by the
+    //     first tree which doesn't share a high bit with the position
     let before = forest & pos;
     let after = forest ^ before;
     let tree = after.ilog2();
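A worked example of the bit manipulation above, as a self-contained sketch (the range check mirrors what the crate's function has to guarantee before calling `ilog2`):

```rust
/// Returns the depth (bit position) of the forest tree owning `pos`.
fn tree_for(pos: usize, forest: usize) -> Option<u32> {
    if pos >= forest {
        return None; // position lies outside the forest
    }
    let before = forest & pos;
    let after = forest ^ before;
    Some(after.ilog2())
}

fn main() {
    // forest 0b110: a 4-leaf tree (bit 2) followed by a 2-leaf tree (bit 1).
    // Positions 0..=3 belong to the 4-leaf tree, positions 4..=5 to the 2-leaf one.
    assert_eq!(tree_for(3, 0b110), Some(2));
    assert_eq!(tree_for(4, 0b110), Some(1));
    assert_eq!(tree_for(6, 0b110), None);
}
```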


@@ -1,12 +1,15 @@
+use alloc::{
+    collections::{BTreeMap, BTreeSet},
+    vec::Vec,
+};
+
+use winter_utils::{Deserializable, Serializable};
+
 use super::{MmrDelta, MmrProof, Rpo256, RpoDigest};
 use crate::merkle::{
     mmr::{leaf_to_corresponding_tree, nodes_in_forest},
     InOrderIndex, InnerNodeInfo, MerklePath, MmrError, MmrPeaks,
 };
-use alloc::{
-    collections::{BTreeMap, BTreeSet},
-    vec::Vec,
-};

 // TYPE ALIASES
 // ================================================================================================
@@ -142,7 +145,7 @@ impl PartialMmr {
     /// in the underlying MMR.
     pub fn open(&self, pos: usize) -> Result<Option<MmrProof>, MmrError> {
         let tree_bit =
-            leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::InvalidPosition(pos))?;
+            leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::PositionNotFound(pos))?;
         let depth = tree_bit as usize;

         let mut nodes = Vec::with_capacity(depth);
@@ -183,7 +186,7 @@ impl PartialMmr {
     pub fn inner_nodes<'a, I: Iterator<Item = (usize, RpoDigest)> + 'a>(
         &'a self,
         mut leaves: I,
-    ) -> impl Iterator<Item = InnerNodeInfo> + '_ {
+    ) -> impl Iterator<Item = InnerNodeInfo> + 'a {
         let stack = if let Some((pos, leaf)) = leaves.next() {
             let idx = InOrderIndex::from_leaf_pos(pos);
             vec![(idx, leaf)]
@@ -295,12 +298,12 @@ impl PartialMmr {
         // invalid.
         let tree = 1 << path.depth();
         if tree & self.forest == 0 {
-            return Err(MmrError::UnknownPeak);
+            return Err(MmrError::UnknownPeak(path.depth()));
         };

         if leaf_pos + 1 == self.forest
             && path.depth() == 0
-            && self.peaks.last().map_or(false, |v| *v == leaf)
+            && self.peaks.last().is_some_and(|v| *v == leaf)
         {
             self.track_latest = true;
             return Ok(());
@@ -316,9 +319,11 @@ impl PartialMmr {
         // Compute the root of the authentication path, and check it matches the current version of
         // the PartialMmr.
-        let computed = path.compute_root(path_idx as u64, leaf).map_err(MmrError::MerkleError)?;
+        let computed = path
+            .compute_root(path_idx as u64, leaf)
+            .map_err(MmrError::MerkleRootComputationFailed)?;
         if self.peaks[peak_pos] != computed {
-            return Err(MmrError::InvalidPeak);
+            return Err(MmrError::PeakPathMismatch);
         }

         let mut idx = InOrderIndex::from_leaf_pos(leaf_pos);
@@ -353,7 +358,10 @@ impl PartialMmr {
     /// inserted into the partial MMR.
     pub fn apply(&mut self, delta: MmrDelta) -> Result<Vec<(InOrderIndex, RpoDigest)>, MmrError> {
         if delta.forest < self.forest {
-            return Err(MmrError::InvalidPeaks);
+            return Err(MmrError::InvalidPeaks(format!(
+                "forest of mmr delta {} is less than current forest {}",
+                delta.forest, self.forest
+            )));
         }

         let mut inserted_nodes = Vec::new();
@@ -536,7 +544,7 @@ pub struct InnerNodeIterator<'a, I: Iterator<Item = (usize, RpoDigest)>> {
     seen_nodes: BTreeSet<InOrderIndex>,
 }

-impl<'a, I: Iterator<Item = (usize, RpoDigest)>> Iterator for InnerNodeIterator<'a, I> {
+impl<I: Iterator<Item = (usize, RpoDigest)>> Iterator for InnerNodeIterator<'_, I> {
     type Item = InnerNodeInfo;

     fn next(&mut self) -> Option<Self::Item> {
@@ -572,6 +580,28 @@ impl<'a, I: Iterator<Item = (usize, RpoDigest)>> Iterator for InnerNodeIterator<
     }
 }

+impl Serializable for PartialMmr {
+    fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
+        self.forest.write_into(target);
+        self.peaks.write_into(target);
+        self.nodes.write_into(target);
+        target.write_bool(self.track_latest);
+    }
+}
+
+impl Deserializable for PartialMmr {
+    fn read_from<R: winter_utils::ByteReader>(
+        source: &mut R,
+    ) -> Result<Self, winter_utils::DeserializationError> {
+        let forest = usize::read_from(source)?;
+        let peaks = Vec::<RpoDigest>::read_from(source)?;
+        let nodes = NodeMap::read_from(source)?;
+        let track_latest = source.read_bool()?;
+
+        Ok(Self { forest, peaks, nodes, track_latest })
+    }
+}
+
 // UTILS
 // ================================================================================================
@@ -613,12 +643,15 @@ fn forest_to_rightmost_index(forest: usize) -> InOrderIndex {
 #[cfg(test)]
 mod tests {
+    use alloc::{collections::BTreeSet, vec::Vec};
+
+    use winter_utils::{Deserializable, Serializable};
+
     use super::{
         forest_to_rightmost_index, forest_to_root_index, InOrderIndex, MmrPeaks, PartialMmr,
         RpoDigest,
     };
     use crate::merkle::{int_to_node, MerkleStore, Mmr, NodeIndex};
-    use alloc::{collections::BTreeSet, vec::Vec};

     const LEAVES: [RpoDigest; 7] = [
         int_to_node(0),
@@ -688,18 +721,18 @@ mod tests {
         // build an MMR with 10 nodes (2 peaks) and a partial MMR based on it
         let mut mmr = Mmr::default();
         (0..10).for_each(|i| mmr.add(int_to_node(i)));
-        let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into();
+        let mut partial_mmr: PartialMmr = mmr.peaks().into();

         // add authentication path for position 1 and 8
         {
             let node = mmr.get(1).unwrap();
-            let proof = mmr.open(1, mmr.forest()).unwrap();
+            let proof = mmr.open(1).unwrap();
             partial_mmr.track(1, node, &proof.merkle_path).unwrap();
         }

         {
             let node = mmr.get(8).unwrap();
-            let proof = mmr.open(8, mmr.forest()).unwrap();
+            let proof = mmr.open(8).unwrap();
             partial_mmr.track(8, node, &proof.merkle_path).unwrap();
         }
@@ -712,7 +745,7 @@ mod tests {
         validate_apply_delta(&mmr, &mut partial_mmr);
         {
             let node = mmr.get(12).unwrap();
-            let proof = mmr.open(12, mmr.forest()).unwrap();
+            let proof = mmr.open(12).unwrap();
             partial_mmr.track(12, node, &proof.merkle_path).unwrap();
             assert!(partial_mmr.track_latest);
         }
@@ -737,7 +770,7 @@ mod tests {
         let nodes_delta = partial.apply(delta).unwrap();

         // new peaks were computed correctly
-        assert_eq!(mmr.peaks(mmr.forest()).unwrap(), partial.peaks());
+        assert_eq!(mmr.peaks(), partial.peaks());

         let mut expected_nodes = nodes_before;
         for (key, value) in nodes_delta {
@@ -753,7 +786,7 @@ mod tests {
             let index_value: u64 = index.into();
             let pos = index_value / 2;
             let proof1 = partial.open(pos as usize).unwrap().unwrap();
-            let proof2 = mmr.open(pos as usize, mmr.forest()).unwrap();
+            let proof2 = mmr.open(pos as usize).unwrap();
             assert_eq!(proof1, proof2);
         }
     }
@@ -762,16 +795,16 @@ mod tests {
 fn test_partial_mmr_inner_nodes_iterator() {
     // build the MMR
     let mmr: Mmr = LEAVES.into();
-    let first_peak = mmr.peaks(mmr.forest).unwrap().peaks()[0];
+    let first_peak = mmr.peaks().peaks()[0];

     // -- test single tree ----------------------------

     // get path and node for position 1
     let node1 = mmr.get(1).unwrap();
-    let proof1 = mmr.open(1, mmr.forest()).unwrap();
+    let proof1 = mmr.open(1).unwrap();

     // create partial MMR and add authentication path to node at position 1
-    let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into();
+    let mut partial_mmr: PartialMmr = mmr.peaks().into();
     partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();

     // empty iterator should have no nodes
@@ -789,13 +822,13 @@ mod tests {
     // -- test no duplicates --------------------------

     // build the partial MMR
-    let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into();
+    let mut partial_mmr: PartialMmr = mmr.peaks().into();
     let node0 = mmr.get(0).unwrap();
-    let proof0 = mmr.open(0, mmr.forest()).unwrap();
+    let proof0 = mmr.open(0).unwrap();
     let node2 = mmr.get(2).unwrap();
-    let proof2 = mmr.open(2, mmr.forest()).unwrap();
+    let proof2 = mmr.open(2).unwrap();

     partial_mmr.track(0, node0, &proof0.merkle_path).unwrap();
     partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();
@@ -826,10 +859,10 @@ mod tests {
     // -- test multiple trees -------------------------

     // build the partial MMR
-    let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into();
+    let mut partial_mmr: PartialMmr = mmr.peaks().into();
     let node5 = mmr.get(5).unwrap();
-    let proof5 = mmr.open(5, mmr.forest()).unwrap();
+    let proof5 = mmr.open(5).unwrap();

     partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();
     partial_mmr.track(5, node5, &proof5.merkle_path).unwrap();
@@ -841,7 +874,7 @@ mod tests {
     let index1 = NodeIndex::new(2, 1).unwrap();
     let index5 = NodeIndex::new(1, 1).unwrap();

-    let second_peak = mmr.peaks(mmr.forest).unwrap().peaks()[1];
+    let second_peak = mmr.peaks().peaks()[1];

     let path1 = store.get_path(first_peak, index1).unwrap().path;
     let path5 = store.get_path(second_peak, index5).unwrap().path;
@@ -860,8 +893,7 @@ mod tests {
         mmr.add(el);
         partial_mmr.add(el, false);

-        let mmr_peaks = mmr.peaks(mmr.forest()).unwrap();
-        assert_eq!(mmr_peaks, partial_mmr.peaks());
+        assert_eq!(mmr.peaks(), partial_mmr.peaks());
         assert_eq!(mmr.forest(), partial_mmr.forest());
     }
 }
@@ -877,12 +909,11 @@ mod tests {
         mmr.add(el);
         partial_mmr.add(el, true);

-        let mmr_peaks = mmr.peaks(mmr.forest()).unwrap();
-        assert_eq!(mmr_peaks, partial_mmr.peaks());
+        assert_eq!(mmr.peaks(), partial_mmr.peaks());
         assert_eq!(mmr.forest(), partial_mmr.forest());

         for pos in 0..i {
-            let mmr_proof = mmr.open(pos as usize, mmr.forest()).unwrap();
+            let mmr_proof = mmr.open(pos as usize).unwrap();
             let partialmmr_proof = partial_mmr.open(pos as usize).unwrap().unwrap();
             assert_eq!(mmr_proof, partialmmr_proof);
         }
@@ -894,8 +925,8 @@ mod tests {
     let mut mmr = Mmr::from((0..7).map(int_to_node));

     // derive a partial Mmr from it which tracks authentication path to leaf 5
-    let mut partial_mmr = PartialMmr::from_peaks(mmr.peaks(mmr.forest()).unwrap());
-    let path_to_5 = mmr.open(5, mmr.forest()).unwrap().merkle_path;
+    let mut partial_mmr = PartialMmr::from_peaks(mmr.peaks());
+    let path_to_5 = mmr.open(5).unwrap().merkle_path;
     let leaf_at_5 = mmr.get(5).unwrap();
     partial_mmr.track(5, leaf_at_5, &path_to_5).unwrap();
@@ -905,6 +936,17 @@ mod tests {
     partial_mmr.add(leaf_at_7, false);

     // the openings should be the same
-    assert_eq!(mmr.open(5, mmr.forest()).unwrap(), partial_mmr.open(5).unwrap().unwrap());
+    assert_eq!(mmr.open(5).unwrap(), partial_mmr.open(5).unwrap().unwrap());
 }
+
+#[test]
+fn test_partial_mmr_serialization() {
+    let mmr = Mmr::from((0..7).map(int_to_node));
+    let partial_mmr = PartialMmr::from_peaks(mmr.peaks());
+
+    let bytes = partial_mmr.to_bytes();
+    let decoded = PartialMmr::read_from_bytes(&bytes).unwrap();
+
+    assert_eq!(partial_mmr, decoded);
+}
 }


@@ -1,6 +1,7 @@
-use super::{super::ZERO, Felt, MmrError, MmrProof, Rpo256, RpoDigest, Word};
 use alloc::vec::Vec;
+
+use super::{super::ZERO, Felt, MmrError, MmrProof, Rpo256, RpoDigest, Word};

 // MMR PEAKS
 // ================================================================================================
@@ -18,12 +19,12 @@ pub struct MmrPeaks {
     ///
     /// Examples:
     ///
-    /// - With 5 leaves, the binary `0b101`. The number of set bits is equal the number
-    ///   of peaks, in this case there are 2 peaks. The 0-indexed least-significant position of
-    ///   the bit determines the number of elements of a tree, so the rightmost tree has `2**0`
-    ///   elements and the left most has `2**2`.
-    /// - With 12 leaves, the binary is `0b1100`, this case also has 2 peaks, the
-    ///   leftmost tree has `2**3=8` elements, and the right most has `2**2=4` elements.
+    /// - With 5 leaves, the binary `0b101`. The number of set bits is equal the number of peaks,
+    ///   in this case there are 2 peaks. The 0-indexed least-significant position of the bit
+    ///   determines the number of elements of a tree, so the rightmost tree has `2**0` elements
+    ///   and the left most has `2**2`.
+    /// - With 12 leaves, the binary is `0b1100`, this case also has 2 peaks, the leftmost tree has
+    ///   `2**3=8` elements, and the right most has `2**2=4` elements.
     num_leaves: usize,

     /// All the peaks of every tree in the MMR forest. The peaks are always ordered by number of
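A quick sanity check of the encoding described in the doc comment above, as a sketch:

```rust
fn main() {
    // 5 leaves = 0b101: two set bits, hence two peaks.
    let num_leaves: usize = 0b101;
    assert_eq!(num_leaves.count_ones(), 2);

    // Each set bit at position k corresponds to a tree with 2^k leaves.
    let tree_sizes: Vec<usize> =
        (0..usize::BITS).filter(|k| num_leaves >> k & 1 == 1).map(|k| 1usize << k).collect();
    assert_eq!(tree_sizes, vec![1, 4]); // 2^0 and 2^2 leaves
}
```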
@@ -44,7 +45,11 @@ impl MmrPeaks {
     /// Returns an error if the number of leaves and the number of peaks are inconsistent.
     pub fn new(num_leaves: usize, peaks: Vec<RpoDigest>) -> Result<Self, MmrError> {
         if num_leaves.count_ones() as usize != peaks.len() {
-            return Err(MmrError::InvalidPeaks);
+            return Err(MmrError::InvalidPeaks(format!(
+                "number of one bits in leaves is {} which does not equal peak length {}",
+                num_leaves.count_ones(),
+                peaks.len()
+            )));
         }

         Ok(Self { num_leaves, peaks })
@@ -68,6 +73,17 @@ impl MmrPeaks {
         &self.peaks
     }

+    /// Returns the peak by the provided index.
+    ///
+    /// # Errors
+    /// Returns an error if the provided peak index is greater or equal to the current number of
+    /// peaks in the Mmr.
+    pub fn get_peak(&self, peak_idx: usize) -> Result<&RpoDigest, MmrError> {
+        self.peaks
+            .get(peak_idx)
+            .ok_or(MmrError::PeakOutOfBounds { peak_idx, peaks_len: self.peaks.len() })
+    }
+
     /// Converts this [MmrPeaks] into its components: number of leaves and a vector of peaks of
     /// the underlying MMR.
     pub fn into_parts(self) -> (usize, Vec<RpoDigest>) {
@@ -83,9 +99,18 @@ impl MmrPeaks {
         Rpo256::hash_elements(&self.flatten_and_pad_peaks())
     }

-    pub fn verify(&self, value: RpoDigest, opening: MmrProof) -> bool {
-        let root = &self.peaks[opening.peak_index()];
-        opening.merkle_path.verify(opening.relative_pos() as u64, value, root)
+    /// Verifies the Merkle opening proof.
+    ///
+    /// # Errors
+    /// Returns an error if:
+    /// - provided opening proof is invalid.
+    /// - Mmr root value computed using the provided leaf value differs from the actual one.
+    pub fn verify(&self, value: RpoDigest, opening: MmrProof) -> Result<(), MmrError> {
+        let root = self.get_peak(opening.peak_index())?;
+        opening
+            .merkle_path
+            .verify(opening.relative_pos() as u64, value, root)
+            .map_err(MmrError::InvalidMerklePath)
     }
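Since `verify` now returns `Result<(), MmrError>` rather than `bool`, callers can distinguish a malformed proof from a root mismatch. A hedged sketch of handling the new variants (import paths assumed as before):

```rust
use miden_crypto::hash::rpo::RpoDigest;
use miden_crypto::merkle::{MmrError, MmrPeaks, MmrProof};

fn check(peaks: &MmrPeaks, leaf: RpoDigest, proof: MmrProof) -> bool {
    match peaks.verify(leaf, proof) {
        Ok(()) => true,
        Err(MmrError::PeakOutOfBounds { peak_idx, peaks_len }) => {
            eprintln!("proof points at peak {peak_idx}, but there are only {peaks_len} peaks");
            false
        },
        Err(err) => {
            // e.g. MmrError::InvalidMerklePath if the path itself is malformed
            eprintln!("verification failed: {err}");
            false
        },
    }
}
```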
     /// Flattens and pads the peaks to make hashing inside of the Miden VM easier.
@@ -94,16 +119,15 @@
     ///
     /// - Flatten the vector of Words into a vector of Felts.
     /// - Pad the peaks with ZERO to an even number of words, this removes the need to handle RPO
     ///   padding.
-    /// - Pad the peaks to a minimum length of 16 words, which reduces the constant cost of
-    ///   hashing.
+    /// - Pad the peaks to a minimum length of 16 words, which reduces the constant cost of hashing.
     pub fn flatten_and_pad_peaks(&self) -> Vec<Felt> {
         let num_peaks = self.peaks.len();

         // To achieve the padding rules above we calculate the length of the final vector.
         // This is calculated as the number of field elements. Each peak is 4 field elements.
         // The length is calculated as follows:
-        // - If there are less than 16 peaks, the data is padded to 16 peaks and as such requires
-        //   64 field elements.
+        // - If there are less than 16 peaks, the data is padded to 16 peaks and as such requires 64
+        //   field elements.
         // - If there are more than 16 peaks and the number of peaks is odd, the data is padded to
         //   an even number of peaks and as such requires `(num_peaks + 1) * 4` field elements.
         // - If there are more than 16 peaks and the number of peaks is even, the data is not padded
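The padding rules in the comment above reduce to a small closed-form length computation; a sketch (4 felts per peak, constants taken from the comment):

```rust
fn padded_len(num_peaks: usize) -> usize {
    if num_peaks < 16 {
        16 * 4 // fewer than 16 peaks: pad up to 16 peaks = 64 felts
    } else if num_peaks % 2 == 1 {
        (num_peaks + 1) * 4 // odd count: pad to an even number of peaks
    } else {
        num_peaks * 4 // even count: no padding needed
    }
}

fn main() {
    assert_eq!(padded_len(2), 64);
    assert_eq!(padded_len(17), 72);
    assert_eq!(padded_len(18), 72);
}
```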


@@ -1,3 +1,5 @@
+use alloc::vec::Vec;
+
 use super::{
     super::{InnerNodeInfo, Rpo256, RpoDigest},
     bit::TrueBitPositionIterator,
@@ -8,7 +10,6 @@ use crate::{
     merkle::{int_to_node, InOrderIndex, MerklePath, MerkleTree, MmrProof, NodeIndex},
     Felt, Word,
 };
-use alloc::vec::Vec;

 #[test]
 fn test_position_equal_or_higher_than_leafs_is_never_contained() {
@@ -138,7 +139,7 @@ fn test_mmr_simple() {
     assert_eq!(mmr.nodes.len(), 1);
     assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
-    let acc = mmr.peaks(mmr.forest()).unwrap();
+    let acc = mmr.peaks();
     assert_eq!(acc.num_leaves(), 1);
     assert_eq!(acc.peaks(), &[postorder[0]]);
@@ -147,7 +148,7 @@ fn test_mmr_simple() {
     assert_eq!(mmr.nodes.len(), 3);
     assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
-    let acc = mmr.peaks(mmr.forest()).unwrap();
+    let acc = mmr.peaks();
     assert_eq!(acc.num_leaves(), 2);
     assert_eq!(acc.peaks(), &[postorder[2]]);
@@ -156,7 +157,7 @@ fn test_mmr_simple() {
     assert_eq!(mmr.nodes.len(), 4);
     assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
-    let acc = mmr.peaks(mmr.forest()).unwrap();
+    let acc = mmr.peaks();
     assert_eq!(acc.num_leaves(), 3);
     assert_eq!(acc.peaks(), &[postorder[2], postorder[3]]);
@@ -165,7 +166,7 @@ fn test_mmr_simple() {
     assert_eq!(mmr.nodes.len(), 7);
     assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
-    let acc = mmr.peaks(mmr.forest()).unwrap();
+    let acc = mmr.peaks();
     assert_eq!(acc.num_leaves(), 4);
     assert_eq!(acc.peaks(), &[postorder[6]]);
@@ -174,7 +175,7 @@ fn test_mmr_simple() {
     assert_eq!(mmr.nodes.len(), 8);
     assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
-    let acc = mmr.peaks(mmr.forest()).unwrap();
+    let acc = mmr.peaks();
     assert_eq!(acc.num_leaves(), 5);
     assert_eq!(acc.peaks(), &[postorder[6], postorder[7]]);
@@ -183,7 +184,7 @@ fn test_mmr_simple() {
     assert_eq!(mmr.nodes.len(), 10);
     assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
-    let acc = mmr.peaks(mmr.forest()).unwrap();
+    let acc = mmr.peaks();
     assert_eq!(acc.num_leaves(), 6);
     assert_eq!(acc.peaks(), &[postorder[6], postorder[9]]);
@@ -192,7 +193,7 @@ fn test_mmr_simple() {
     assert_eq!(mmr.nodes.len(), 11);
     assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
-    let acc = mmr.peaks(mmr.forest()).unwrap();
+    let acc = mmr.peaks();
     assert_eq!(acc.num_leaves(), 7);
     assert_eq!(acc.peaks(), &[postorder[6], postorder[9], postorder[10]]);
 }
@@ -204,97 +205,73 @@ fn test_mmr_open() {
     let h23 = merge(LEAVES[2], LEAVES[3]);

     // node at pos 7 is the root
-    assert!(
-        mmr.open(7, mmr.forest()).is_err(),
-        "Element 7 is not in the tree, result should be None"
-    );
+    assert!(mmr.open(7).is_err(), "Element 7 is not in the tree, result should be None");

     // node at pos 6 is the root
     let empty: MerklePath = MerklePath::new(vec![]);
     let opening = mmr
-        .open(6, mmr.forest())
+        .open(6)
         .expect("Element 6 is contained in the tree, expected an opening result.");
     assert_eq!(opening.merkle_path, empty);
     assert_eq!(opening.forest, mmr.forest);
     assert_eq!(opening.position, 6);
-    assert!(
-        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[6], opening),
-        "MmrProof should be valid for the current accumulator."
-    );
+    mmr.peaks().verify(LEAVES[6], opening).unwrap();

     // nodes 4,5 are depth 1
     let root_to_path = MerklePath::new(vec![LEAVES[4]]);
     let opening = mmr
-        .open(5, mmr.forest())
+        .open(5)
         .expect("Element 5 is contained in the tree, expected an opening result.");
     assert_eq!(opening.merkle_path, root_to_path);
     assert_eq!(opening.forest, mmr.forest);
     assert_eq!(opening.position, 5);
-    assert!(
-        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[5], opening),
-        "MmrProof should be valid for the current accumulator."
-    );
+    mmr.peaks().verify(LEAVES[5], opening).unwrap();

     let root_to_path = MerklePath::new(vec![LEAVES[5]]);
     let opening = mmr
-        .open(4, mmr.forest())
+        .open(4)
         .expect("Element 4 is contained in the tree, expected an opening result.");
     assert_eq!(opening.merkle_path, root_to_path);
     assert_eq!(opening.forest, mmr.forest);
     assert_eq!(opening.position, 4);
-    assert!(
-        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[4], opening),
-        "MmrProof should be valid for the current accumulator."
-    );
+    mmr.peaks().verify(LEAVES[4], opening).unwrap();

     // nodes 0,1,2,3 are depth 2
     let root_to_path = MerklePath::new(vec![LEAVES[2], h01]);
     let opening = mmr
-        .open(3, mmr.forest())
+        .open(3)
         .expect("Element 3 is contained in the tree, expected an opening result.");
     assert_eq!(opening.merkle_path, root_to_path);
     assert_eq!(opening.forest, mmr.forest);
     assert_eq!(opening.position, 3);
-    assert!(
-        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[3], opening),
-        "MmrProof should be valid for the current accumulator."
-    );
+    mmr.peaks().verify(LEAVES[3], opening).unwrap();

     let root_to_path = MerklePath::new(vec![LEAVES[3], h01]);
     let opening = mmr
-        .open(2, mmr.forest())
+        .open(2)
         .expect("Element 2 is contained in the tree, expected an opening result.");
     assert_eq!(opening.merkle_path, root_to_path);
     assert_eq!(opening.forest, mmr.forest);
     assert_eq!(opening.position, 2);
-    assert!(
-        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[2], opening),
-        "MmrProof should be valid for the current accumulator."
-    );
+    mmr.peaks().verify(LEAVES[2], opening).unwrap();

     let root_to_path = MerklePath::new(vec![LEAVES[0], h23]);
     let opening = mmr
-        .open(1, mmr.forest())
+        .open(1)
         .expect("Element 1 is contained in the tree, expected an opening result.");
     assert_eq!(opening.merkle_path, root_to_path);
     assert_eq!(opening.forest, mmr.forest);
     assert_eq!(opening.position, 1);
-    assert!(
-        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[1], opening),
-        "MmrProof should be valid for the current accumulator."
-    );
+    mmr.peaks().verify(LEAVES[1], opening).unwrap();

     let root_to_path = MerklePath::new(vec![LEAVES[1], h23]);
     let opening = mmr
-        .open(0, mmr.forest())
+        .open(0)
         .expect("Element 0 is contained in the tree, expected an opening result.");
     assert_eq!(opening.merkle_path, root_to_path);
     assert_eq!(opening.forest, mmr.forest);
     assert_eq!(opening.position, 0);
-    assert!(
-        mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[0], opening),
-        "MmrProof should be valid for the current accumulator."
-    );
+    mmr.peaks().verify(LEAVES[0], opening).unwrap();
 }

 #[test]
@@ -308,7 +285,7 @@ fn test_mmr_open_older_version() {
     // merkle path of a node is empty if there are no elements to pair with it
     for pos in (0..mmr.forest()).filter(is_even) {
         let forest = pos + 1;
-        let proof = mmr.open(pos, forest).unwrap();
+        let proof = mmr.open_at(pos, forest).unwrap();
         assert_eq!(proof.forest, forest);
         assert_eq!(proof.merkle_path.nodes(), []);
         assert_eq!(proof.position, pos);
@@ -320,7 +297,7 @@ fn test_mmr_open_older_version() {
         for pos in 0..4 {
             let idx = NodeIndex::new(2, pos).unwrap();
             let path = mtree.get_path(idx).unwrap();
-            let proof = mmr.open(pos as usize, forest).unwrap();
+            let proof = mmr.open_at(pos as usize, forest).unwrap();
             assert_eq!(path, proof.merkle_path);
         }
     }
@@ -331,7 +308,7 @@ fn test_mmr_open_older_version() {
             let path = mtree.get_path(idx).unwrap();
             // account for the bigger tree with 4 elements
             let mmr_pos = (pos + 4) as usize;
-            let proof = mmr.open(mmr_pos, forest).unwrap();
+            let proof = mmr.open_at(mmr_pos, forest).unwrap();
             assert_eq!(path, proof.merkle_path);
         }
     }
@@ -357,49 +334,49 @@ fn test_mmr_open_eight() {
     let root = mtree.root();

     let position = 0;
-    let proof = mmr.open(position, mmr.forest()).unwrap();
+    let proof = mmr.open(position).unwrap();
     let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

     let position = 1;
-    let proof = mmr.open(position, mmr.forest()).unwrap();
+    let proof = mmr.open(position).unwrap();
     let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

     let position = 2;
-    let proof = mmr.open(position, mmr.forest()).unwrap();
+    let proof = mmr.open(position).unwrap();
     let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

     let position = 3;
-    let proof = mmr.open(position, mmr.forest()).unwrap();
+    let proof = mmr.open(position).unwrap();
     let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

     let position = 4;
-    let proof = mmr.open(position, mmr.forest()).unwrap();
+    let proof = mmr.open(position).unwrap();
     let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

     let position = 5;
-    let proof = mmr.open(position, mmr.forest()).unwrap();
+    let proof = mmr.open(position).unwrap();
     let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

     let position = 6;
-    let proof = mmr.open(position, mmr.forest()).unwrap();
+    let proof = mmr.open(position).unwrap();
     let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);

     let position = 7;
-    let proof = mmr.open(position, mmr.forest()).unwrap();
+    let proof = mmr.open(position).unwrap();
     let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
@@ -415,47 +392,47 @@ fn test_mmr_open_seven() {
     let mmr: Mmr = LEAVES.into();

     let position = 0;
-    let proof = mmr.open(position, mmr.forest()).unwrap();
+    let proof = mmr.open(position).unwrap();
     let merkle_path: MerklePath =
         mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(0, LEAVES[0]).unwrap(), mtree1.root());

     let position = 1;
-    let proof = mmr.open(position, mmr.forest()).unwrap();
+    let proof = mmr.open(position).unwrap();
     let merkle_path: MerklePath =
         mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(1, LEAVES[1]).unwrap(), mtree1.root());

     let position = 2;
-    let proof = mmr.open(position, mmr.forest()).unwrap();
+    let proof = mmr.open(position).unwrap();
     let merkle_path: MerklePath =
         mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(2, LEAVES[2]).unwrap(), mtree1.root());

     let position = 3;
-    let proof = mmr.open(position, mmr.forest()).unwrap();
+    let proof = mmr.open(position).unwrap();
     let merkle_path: MerklePath =
         mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(3, LEAVES[3]).unwrap(), mtree1.root());

     let position = 4;
-    let proof = mmr.open(position, mmr.forest()).unwrap();
+    let proof = mmr.open(position).unwrap();
     let merkle_path: MerklePath = mtree2.get_path(NodeIndex::new(1, 0u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(0, LEAVES[4]).unwrap(), mtree2.root());

     let position = 5;
-    let proof = mmr.open(position, mmr.forest()).unwrap();
+    let proof = mmr.open(position).unwrap();
     let merkle_path: MerklePath = mtree2.get_path(NodeIndex::new(1, 1u64).unwrap()).unwrap();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(1, LEAVES[5]).unwrap(), mtree2.root());

     let position = 6;
-    let proof = mmr.open(position, mmr.forest()).unwrap();
+    let proof = mmr.open(position).unwrap();
     let merkle_path: MerklePath = [].as_ref().into();
     assert_eq!(proof, MmrProof { forest, position, merkle_path });
     assert_eq!(proof.merkle_path.compute_root(0, LEAVES[6]).unwrap(), LEAVES[6]);
@@ -479,7 +456,7 @@ fn test_mmr_invariants() {
     let mut mmr = Mmr::new();
     for v in 1..=1028 {
         mmr.add(int_to_node(v));
-        let accumulator = mmr.peaks(mmr.forest()).unwrap();
+        let accumulator = mmr.peaks();
         assert_eq!(v as usize, mmr.forest(), "MMR leaf count must increase by one on every add");
         assert_eq!(
             v as usize,
@@ -565,37 +542,37 @@ fn test_mmr_peaks() {
     let mmr: Mmr = LEAVES.into();

     let forest = 0b0001;
-    let acc = mmr.peaks(forest).unwrap();
+    let acc = mmr.peaks_at(forest).unwrap();
     assert_eq!(acc.num_leaves(), forest);
     assert_eq!(acc.peaks(), &[mmr.nodes[0]]);

     let forest = 0b0010;
-    let acc = mmr.peaks(forest).unwrap();
+    let acc = mmr.peaks_at(forest).unwrap();
     assert_eq!(acc.num_leaves(), forest);
     assert_eq!(acc.peaks(), &[mmr.nodes[2]]);

     let forest = 0b0011;
-    let acc = mmr.peaks(forest).unwrap();
+    let acc = mmr.peaks_at(forest).unwrap();
     assert_eq!(acc.num_leaves(), forest);
     assert_eq!(acc.peaks(), &[mmr.nodes[2], mmr.nodes[3]]);

     let forest = 0b0100;
-    let acc = mmr.peaks(forest).unwrap();
+    let acc = mmr.peaks_at(forest).unwrap();
     assert_eq!(acc.num_leaves(), forest);
     assert_eq!(acc.peaks(), &[mmr.nodes[6]]);

     let forest = 0b0101;
-    let acc = mmr.peaks(forest).unwrap();
+    let acc = mmr.peaks_at(forest).unwrap();
     assert_eq!(acc.num_leaves(), forest);
     assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[7]]);

     let forest = 0b0110;
-    let acc = mmr.peaks(forest).unwrap();
+    let acc = mmr.peaks_at(forest).unwrap();
     assert_eq!(acc.num_leaves(), forest);
     assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[9]]);

     let forest = 0b0111;
-    let acc = mmr.peaks(forest).unwrap();
+    let acc = mmr.peaks_at(forest).unwrap();
     assert_eq!(acc.num_leaves(), forest);
     assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[9], mmr.nodes[10]]);
 }
@@ -603,7 +580,7 @@ fn test_mmr_peaks() {
 #[test]
 fn test_mmr_hash_peaks() {
     let mmr: Mmr = LEAVES.into();
-    let peaks = mmr.peaks(mmr.forest()).unwrap();
+    let peaks = mmr.peaks();

     let first_peak = Rpo256::merge(&[
         Rpo256::merge(&[LEAVES[0], LEAVES[1]]),
@@ -657,7 +634,7 @@ fn test_mmr_peaks_hash_odd() {
 #[test]
 fn test_mmr_delta() {
     let mmr: Mmr = LEAVES.into();
-    let acc = mmr.peaks(mmr.forest()).unwrap();
+    let acc = mmr.peaks();

     // original_forest can't have more elements
     assert!(
@@ -757,7 +734,7 @@ fn test_mmr_delta_old_forest() {
 #[test]
 fn test_partial_mmr_simple() {
     let mmr: Mmr = LEAVES.into();
-    let peaks = mmr.peaks(mmr.forest()).unwrap();
+    let peaks = mmr.peaks();
     let mut partial: PartialMmr = peaks.clone().into();

     // check initial state of the partial mmr
@@ -768,7 +745,7 @@ fn test_partial_mmr_simple() {
    assert_eq!(partial.nodes.len(), 0);

    // check state after tracking one element
-   let proof1 = mmr.open(0, mmr.forest()).unwrap();
+   let proof1 = mmr.open(0).unwrap();
    let el1 = mmr.get(proof1.position).unwrap();
    partial.track(proof1.position, el1, &proof1.merkle_path).unwrap();
@@ -780,7 +757,7 @@ fn test_partial_mmr_simple() {
    let idx = idx.parent();
    assert_eq!(partial.nodes[&idx.sibling()], proof1.merkle_path[1]);

-   let proof2 = mmr.open(1, mmr.forest()).unwrap();
+   let proof2 = mmr.open(1).unwrap();
    let el2 = mmr.get(proof2.position).unwrap();
    partial.track(proof2.position, el2, &proof2.merkle_path).unwrap();
@@ -798,9 +775,9 @@ fn test_partial_mmr_update_single() {
    let mut full = Mmr::new();
    let zero = int_to_node(0);
    full.add(zero);
-   let mut partial: PartialMmr = full.peaks(full.forest()).unwrap().into();
+   let mut partial: PartialMmr = full.peaks().into();

-   let proof = full.open(0, full.forest()).unwrap();
+   let proof = full.open(0).unwrap();
    partial.track(proof.position, zero, &proof.merkle_path).unwrap();

    for i in 1..100 {
@@ -810,9 +787,9 @@ fn test_partial_mmr_update_single() {
        partial.apply(delta).unwrap();

        assert_eq!(partial.forest(), full.forest());
-       assert_eq!(partial.peaks(), full.peaks(full.forest()).unwrap());
+       assert_eq!(partial.peaks(), full.peaks());

-       let proof1 = full.open(i as usize, full.forest()).unwrap();
+       let proof1 = full.open(i as usize).unwrap();
        partial.track(proof1.position, node, &proof1.merkle_path).unwrap();
        let proof2 = partial.open(proof1.position).unwrap().unwrap();
        assert_eq!(proof1.merkle_path, proof2.merkle_path);
@@ -822,7 +799,7 @@ fn test_partial_mmr_update_single() {
#[test]
fn test_mmr_add_invalid_odd_leaf() {
    let mmr: Mmr = LEAVES.into();
-   let acc = mmr.peaks(mmr.forest()).unwrap();
+   let acc = mmr.peaks();
    let mut partial: PartialMmr = acc.clone().into();

    let empty = MerklePath::new(Vec::new());
@@ -837,6 +814,39 @@ fn test_mmr_add_invalid_odd_leaf() {
    assert!(result.is_ok());
}
/// Tests that a proof whose peak count exceeds the peak count of the MMR returns an error.
///
/// Here we manipulate the proof to return a peak index of 1 while the MMR only has 1 peak (with
/// index 0).
#[test]
#[should_panic]
fn test_mmr_proof_num_peaks_exceeds_current_num_peaks() {
let mmr: Mmr = LEAVES[0..4].iter().cloned().into();
let mut proof = mmr.open(3).unwrap();
proof.forest = 5;
proof.position = 4;
mmr.peaks().verify(LEAVES[3], proof).unwrap();
}
/// Tests that a proof whose peak count exceeds the peak count of the MMR returns an error.
///
/// We create an MmrProof for a leaf whose peak index to verify against is 1.
/// Then we add another leaf which results in an Mmr with just one peak due to trees
/// being merged. If we try to use the old proof against the new Mmr, we should get an error.
#[test]
#[should_panic]
fn test_mmr_old_proof_num_peaks_exceeds_current_num_peaks() {
let leaves_len = 3;
let mut mmr = Mmr::from(LEAVES[0..leaves_len].iter().cloned());
let leaf_idx = leaves_len - 1;
let proof = mmr.open(leaf_idx).unwrap();
assert!(mmr.peaks().verify(LEAVES[leaf_idx], proof.clone()).is_ok());
mmr.add(LEAVES[leaves_len]);
mmr.peaks().verify(LEAVES[leaf_idx], proof).unwrap();
}
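For contrast with the two failure cases above, a minimal sketch of the happy path, reusing the same `Mmr`/`LEAVES` fixtures (illustrative, not part of the patch):

    // Open a leaf and verify it against the current peaks; `verify`
    // returns an error (rather than panicking) on a peak-count mismatch.
    let mmr: Mmr = LEAVES[0..4].iter().cloned().into();
    let proof = mmr.open(3).unwrap();
    assert!(mmr.peaks().verify(LEAVES[3], proof).is_ok());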
mod property_tests {
    use proptest::prelude::*;

View File

@@ -22,8 +22,8 @@ pub use path::{MerklePath, RootPath, ValuePath};
mod smt;
pub use smt::{
-   LeafIndex, SimpleSmt, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH,
-   SMT_MAX_DEPTH, SMT_MIN_DEPTH,
+   InnerNode, LeafIndex, MutationSet, NodeMutation, SimpleSmt, Smt, SmtLeaf, SmtLeafError,
+   SmtProof, SmtProofError, SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH,
};

mod mmr;

View File

@@ -116,7 +116,7 @@ impl PartialMerkleTree {
        // depth of 63 because we consider passing in a vector of size 2^64 infeasible.
        let max = 2usize.pow(63);
        if layers.len() > max {
-           return Err(MerkleError::InvalidNumEntries(max));
+           return Err(MerkleError::TooManyEntries(max));
        }

        // Get maximum depth
// Get maximum depth // Get maximum depth
@@ -147,11 +147,12 @@ impl PartialMerkleTree {
            let index = NodeIndex::new(depth, index_value)?;

            // get hash of the current node
-           let node = nodes.get(&index).ok_or(MerkleError::NodeNotInSet(index))?;
+           let node =
+               nodes.get(&index).ok_or(MerkleError::NodeIndexNotFoundInTree(index))?;
            // get hash of the sibling node
            let sibling = nodes
                .get(&index.sibling())
-               .ok_or(MerkleError::NodeNotInSet(index.sibling()))?;
+               .ok_or(MerkleError::NodeIndexNotFoundInTree(index.sibling()))?;
            // get parent hash
            let parent = Rpo256::merge(&index.build_node(*node, *sibling));
@@ -184,7 +185,10 @@ impl PartialMerkleTree {
    /// # Errors
    /// Returns an error if the specified NodeIndex is not contained in the nodes map.
    pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
-       self.nodes.get(&index).ok_or(MerkleError::NodeNotInSet(index)).copied()
+       self.nodes
+           .get(&index)
+           .ok_or(MerkleError::NodeIndexNotFoundInTree(index))
+           .copied()
    }

    /// Returns true if the provided index is contained in the leaves set, false otherwise.
@@ -214,7 +218,7 @@ impl PartialMerkleTree {
    /// # Errors
    /// Returns an error if:
    /// - the specified index has depth set to 0 or the depth is greater than the depth of this
    ///   Merkle tree.
    /// - the specified index is not contained in the nodes map.
    pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
        if index.is_root() {
@@ -224,7 +228,7 @@ impl PartialMerkleTree {
        }
        if !self.nodes.contains_key(&index) {
-           return Err(MerkleError::NodeNotInSet(index));
+           return Err(MerkleError::NodeIndexNotFoundInTree(index));
        }

        let mut path = Vec::new();
@@ -335,15 +339,16 @@ impl PartialMerkleTree {
        if self.root() == EMPTY_DIGEST {
            self.nodes.insert(ROOT_INDEX, root);
        } else if self.root() != root {
-           return Err(MerkleError::ConflictingRoots([self.root(), root].to_vec()));
+           return Err(MerkleError::ConflictingRoots {
+               expected_root: self.root(),
+               actual_root: root,
+           });
        }

        Ok(())
    }
    /// Updates value of the leaf at the specified index returning the old leaf value.
-   /// By default the specified index is assumed to belong to the deepest layer. If the considered
-   /// node does not belong to the tree, the first node on the way to the root will be changed.
    ///
    /// By default the specified index is assumed to belong to the deepest layer. If the considered
    /// node does not belong to the tree, the first node on the way to the root will be changed.
@@ -352,6 +357,7 @@ impl PartialMerkleTree {
    ///
    /// # Errors
    /// Returns an error if:
+   /// - No entry exists at the specified index.
    /// - The specified index is greater than the maximum number of nodes on the deepest layer.
    pub fn update_leaf(&mut self, index: u64, value: Word) -> Result<RpoDigest, MerkleError> {
        let mut node_index = NodeIndex::new(self.max_depth(), index)?;
@@ -367,7 +373,7 @@ impl PartialMerkleTree {
        let old_value = self
            .nodes
            .insert(node_index, value.into())
-           .ok_or(MerkleError::NodeNotInSet(node_index))?;
+           .ok_or(MerkleError::NodeIndexNotFoundInTree(node_index))?;

        // if the old value and new value are the same, there is nothing to update
        if value == *old_value {

View File

@@ -1,3 +1,5 @@
+ use alloc::{collections::BTreeMap, vec::Vec};
+
use super::{
    super::{
        digests_to_words, int_to_node, DefaultMerkleStore as MerkleStore, MerkleTree, NodeIndex,
@@ -5,7 +7,6 @@ use super::{
    },
    Deserializable, InnerNodeInfo, RpoDigest, Serializable, ValuePath,
};
- use alloc::{collections::BTreeMap, vec::Vec};
// TEST DATA
// ================================================================================================
@@ -294,7 +295,8 @@ fn leaves() {
    assert!(expected_leaves.eq(pmt.leaves()));
}

- /// Checks that nodes of the PMT returned by `inner_nodes()` function are equal to the expected ones.
+ /// Checks that nodes of the PMT returned by `inner_nodes()` function are equal to the expected
+ /// ones.
#[test]
fn test_inner_node_iterator() {
    let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();

View File

@@ -54,12 +54,20 @@ impl MerklePath {
    /// Verifies the Merkle opening proof towards the provided root.
    ///
-   /// Returns `true` if `node` exists at `index` in a Merkle tree with `root`.
-   pub fn verify(&self, index: u64, node: RpoDigest, root: &RpoDigest) -> bool {
-       match self.compute_root(index, node) {
-           Ok(computed_root) => root == &computed_root,
-           Err(_) => false,
-       }
-   }
+   /// # Errors
+   /// Returns an error if:
+   /// - provided node index is invalid.
+   /// - root calculated during the verification differs from the provided one.
+   pub fn verify(&self, index: u64, node: RpoDigest, root: &RpoDigest) -> Result<(), MerkleError> {
+       let computed_root = self.compute_root(index, node)?;
+       if &computed_root != root {
+           return Err(MerkleError::ConflictingRoots {
+               expected_root: *root,
+               actual_root: computed_root,
+           });
+       }
+       Ok(())
+   }
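Callers must now handle a `Result` instead of a `bool`. A sketch of the new calling convention (the `path`, `index`, `node`, and `root` bindings are hypothetical):

    match path.verify(index, node, &root) {
        Ok(()) => println!("opening verified"),
        // An invalid index and a root mismatch are now distinguishable.
        Err(MerkleError::ConflictingRoots { expected_root, actual_root }) => {
            eprintln!("root mismatch: expected {expected_root}, got {actual_root}")
        },
        Err(other) => eprintln!("verification failed: {other}"),
    }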
    /// Returns an iterator over every inner node of this [MerklePath].
@@ -143,7 +151,7 @@ pub struct InnerNodeIterator<'a> {
    value: RpoDigest,
}

- impl<'a> Iterator for InnerNodeIterator<'a> {
+ impl Iterator for InnerNodeIterator<'_> {
    type Item = InnerNodeInfo;

    fn next(&mut self) -> Option<Self::Item> {

View File

@@ -1,86 +1,39 @@
- use alloc::vec::Vec;
- use core::fmt;
+ use thiserror::Error;

use crate::{
    hash::rpo::RpoDigest,
    merkle::{LeafIndex, SMT_DEPTH},
-   Word,
};

// SMT LEAF ERROR
// =================================================================================================

- #[derive(Clone, Debug, PartialEq, Eq)]
+ #[derive(Debug, Error)]
pub enum SmtLeafError {
-   InconsistentKeys {
-       entries: Vec<(RpoDigest, Word)>,
-       key_1: RpoDigest,
-       key_2: RpoDigest,
-   },
-   InvalidNumEntriesForMultiple(usize),
-   SingleKeyInconsistentWithLeafIndex {
+   #[error(
+       "multiple leaf requires all keys to map to the same leaf index but key1 {key_1} and key2 {key_2} map to different indices"
+   )]
+   InconsistentMultipleLeafKeys { key_1: RpoDigest, key_2: RpoDigest },
+   #[error("single leaf key {key} maps to {actual_leaf_index:?} but was expected to map to {expected_leaf_index:?}")]
+   InconsistentSingleLeafIndices {
        key: RpoDigest,
-       leaf_index: LeafIndex<SMT_DEPTH>,
+       expected_leaf_index: LeafIndex<SMT_DEPTH>,
+       actual_leaf_index: LeafIndex<SMT_DEPTH>,
    },
-   MultipleKeysInconsistentWithLeafIndex {
+   #[error("supplied leaf index {leaf_index_supplied:?} does not match {leaf_index_from_keys:?} for multiple leaf")]
+   InconsistentMultipleLeafIndices {
        leaf_index_from_keys: LeafIndex<SMT_DEPTH>,
        leaf_index_supplied: LeafIndex<SMT_DEPTH>,
    },
+   #[error("multiple leaf requires at least two entries but only {0} were given")]
+   MultipleLeafRequiresTwoEntries(usize),
}

- #[cfg(feature = "std")]
- impl std::error::Error for SmtLeafError {}
-
- impl fmt::Display for SmtLeafError {
-     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-         use SmtLeafError::*;
-         match self {
-             InvalidNumEntriesForMultiple(num_entries) => {
-                 write!(f, "Multiple leaf requires 2 or more entries. Got: {num_entries}")
-             }
-             InconsistentKeys { entries, key_1, key_2 } => {
-                 write!(f, "Multiple leaf requires all keys to map to the same leaf index. Offending keys: {key_1} and {key_2}. Entries: {entries:?}.")
-             }
-             SingleKeyInconsistentWithLeafIndex { key, leaf_index } => {
-                 write!(
-                     f,
-                     "Single key in leaf inconsistent with leaf index. Key: {key}, leaf index: {}",
-                     leaf_index.value()
-                 )
-             }
-             MultipleKeysInconsistentWithLeafIndex {
-                 leaf_index_from_keys,
-                 leaf_index_supplied,
-             } => {
-                 write!(
-                     f,
-                     "Keys in entries map to leaf index {}, but leaf index {} was supplied",
-                     leaf_index_from_keys.value(),
-                     leaf_index_supplied.value()
-                 )
-             }
-         }
-     }
- }

// SMT PROOF ERROR
// =================================================================================================

- #[derive(Clone, Debug, PartialEq, Eq)]
+ #[derive(Debug, Error)]
pub enum SmtProofError {
-   InvalidPathLength(usize),
+   #[error("merkle path length {0} does not match SMT depth {SMT_DEPTH}")]
+   InvalidMerklePathLength(usize),
}

- #[cfg(feature = "std")]
- impl std::error::Error for SmtProofError {}
-
- impl fmt::Display for SmtProofError {
-     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-         use SmtProofError::*;
-         match self {
-             InvalidPathLength(path_length) => {
-                 write!(f, "Invalid Merkle path length. Expected {SMT_DEPTH}, got {path_length}")
-             }
-         }
-     }
- }
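A brief sketch of how a downstream caller might match on the reworked variants (hypothetical code, not part of the patch):

    fn describe(err: &SmtLeafError) -> String {
        match err {
            SmtLeafError::MultipleLeafRequiresTwoEntries(n) => {
                format!("need at least two entries, got {n}")
            },
            // thiserror derives Display, so any variant can also be
            // formatted directly.
            other => other.to_string(),
        }
    }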

View File

@@ -20,8 +20,8 @@ impl SmtLeaf {
    ///
    /// # Errors
    /// - Returns an error if 2 keys in `entries` map to a different leaf index
-   /// - Returns an error if 1 or more keys in `entries` map to a leaf index
-   ///   different from `leaf_index`
+   /// - Returns an error if 1 or more keys in `entries` map to a leaf index different from
+   ///   `leaf_index`
    pub fn new(
        entries: Vec<(RpoDigest, Word)>,
        leaf_index: LeafIndex<SMT_DEPTH>,
@@ -31,29 +31,31 @@ impl SmtLeaf {
            1 => {
                let (key, value) = entries[0];
-               if LeafIndex::<SMT_DEPTH>::from(key) != leaf_index {
-                   return Err(SmtLeafError::SingleKeyInconsistentWithLeafIndex {
+               let computed_index = LeafIndex::<SMT_DEPTH>::from(key);
+               if computed_index != leaf_index {
+                   return Err(SmtLeafError::InconsistentSingleLeafIndices {
                        key,
-                       leaf_index,
+                       expected_leaf_index: leaf_index,
+                       actual_leaf_index: computed_index,
                    });
                }
                Ok(Self::new_single(key, value))
-           }
+           },
            _ => {
                let leaf = Self::new_multiple(entries)?;

                // `new_multiple()` checked that all keys map to the same leaf index. We still need
                // to ensure that that leaf index is `leaf_index`.
                if leaf.index() != leaf_index {
-                   Err(SmtLeafError::MultipleKeysInconsistentWithLeafIndex {
+                   Err(SmtLeafError::InconsistentMultipleLeafIndices {
                        leaf_index_from_keys: leaf.index(),
                        leaf_index_supplied: leaf_index,
                    })
                } else {
                    Ok(leaf)
                }
-           }
+           },
        }
    }
@@ -68,14 +70,14 @@ impl SmtLeaf {
        Self::Single((key, value))
    }

-   /// Returns a new single leaf with the specified entry. The leaf index is derived from the
+   /// Returns a new multiple leaf with the specified entries. The leaf index is derived from the
    /// entries' keys.
    ///
    /// # Errors
    /// - Returns an error if 2 keys in `entries` map to a different leaf index
    pub fn new_multiple(entries: Vec<(RpoDigest, Word)>) -> Result<Self, SmtLeafError> {
        if entries.len() < 2 {
-           return Err(SmtLeafError::InvalidNumEntriesForMultiple(entries.len()));
+           return Err(SmtLeafError::MultipleLeafRequiresTwoEntries(entries.len()));
        }
        // Check that all keys map to the same leaf index
@@ -89,8 +91,7 @@ impl SmtLeaf {
            let next_leaf_index: LeafIndex<SMT_DEPTH> = next_key.into();
            if next_leaf_index != first_leaf_index {
-               return Err(SmtLeafError::InconsistentKeys {
-                   entries,
+               return Err(SmtLeafError::InconsistentMultipleLeafKeys {
                    key_1: first_key,
                    key_2: next_key,
                });
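A short sketch of the constructor contract enforced here, using the renamed error variants (keys and values are illustrative):

    // Keys that differ in their most significant element map to different
    // leaf indices, so `new_multiple` must reject the pair.
    let key_a = RpoDigest::from([ONE, ONE, ONE, Felt::new(1)]);
    let key_b = RpoDigest::from([ONE, ONE, ONE, Felt::new(2)]);
    let err = SmtLeaf::new_multiple(vec![(key_a, [ONE; 4]), (key_b, [ONE; 4])]);
    assert!(matches!(err, Err(SmtLeafError::InconsistentMultipleLeafKeys { .. })));

    // A single entry is rejected too: a multiple leaf needs at least two.
    let err = SmtLeaf::new_multiple(vec![(key_a, [ONE; 4])]);
    assert!(matches!(err, Err(SmtLeafError::MultipleLeafRequiresTwoEntries(1))));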
@@ -118,7 +119,7 @@ impl SmtLeaf {
                // Note: All keys are guaranteed to have the same leaf index
                let (first_key, _) = entries[0];
                first_key.into()
-           }
+           },
        }
    }
@@ -129,7 +130,7 @@ impl SmtLeaf {
            SmtLeaf::Single(_) => 1,
            SmtLeaf::Multiple(entries) => {
                entries.len().try_into().expect("shouldn't have more than 2^64 entries")
-           }
+           },
        }
    }
@@ -141,7 +142,7 @@ impl SmtLeaf {
            SmtLeaf::Multiple(kvs) => {
                let elements: Vec<Felt> = kvs.iter().copied().flat_map(kv_to_elements).collect();
                Rpo256::hash_elements(&elements)
-           }
+           },
        }
    }
@@ -182,7 +183,8 @@ impl SmtLeaf {
    // HELPERS
    // ---------------------------------------------------------------------------------------------

-   /// Returns the value associated with `key` in the leaf, or `None` if `key` maps to another leaf.
+   /// Returns the value associated with `key` in the leaf, or `None` if `key` maps to another
+   /// leaf.
    pub(super) fn get_value(&self, key: &RpoDigest) -> Option<Word> {
        // Ensure that `key` maps to this leaf
        if self.index() != key.into() {
@@ -197,7 +199,7 @@ impl SmtLeaf {
                } else {
                    Some(EMPTY_WORD)
                }
-           }
+           },
            SmtLeaf::Multiple(kv_pairs) => {
                for (key_in_leaf, value_in_leaf) in kv_pairs {
                    if key == key_in_leaf {
@@ -206,7 +208,7 @@ impl SmtLeaf {
                }

                Some(EMPTY_WORD)
-           }
+           },
        }
    }
@@ -219,7 +221,7 @@ impl SmtLeaf {
            SmtLeaf::Empty(_) => {
                *self = SmtLeaf::new_single(key, value);
                None
-           }
+           },
            SmtLeaf::Single(kv_pair) => {
                if kv_pair.0 == key {
                    // the key is already in this leaf. Update the value and return the previous
@@ -237,7 +239,7 @@ impl SmtLeaf {
                    None
                }
-           }
+           },
            SmtLeaf::Multiple(kv_pairs) => {
                match kv_pairs.binary_search_by(|kv_pair| cmp_keys(kv_pair.0, key)) {
                    Ok(pos) => {
@@ -245,14 +247,14 @@ impl SmtLeaf {
                        kv_pairs[pos].1 = value;
                        Some(old_value)
-                   }
+                   },
                    Err(pos) => {
                        kv_pairs.insert(pos, (key, value));
                        None
-                   }
+                   },
                }
-           }
+           },
        }
    }
@@ -277,7 +279,7 @@ impl SmtLeaf {
                    // another key is stored at leaf; nothing to update
                    (None, false)
                }
-           }
+           },
            SmtLeaf::Multiple(kv_pairs) => {
                match kv_pairs.binary_search_by(|kv_pair| cmp_keys(kv_pair.0, key)) {
                    Ok(pos) => {
@@ -292,13 +294,13 @@ impl SmtLeaf {
                        }

                        (Some(old_value), false)
-                   }
+                   },
                    Err(_) => {
                        // other keys are stored at leaf; nothing to update
                        (None, false)
-                   }
+                   },
                }
-           }
+           },
        }
    }
}
@@ -349,7 +351,7 @@ impl Deserializable for SmtLeaf {
// ================================================================================================

/// Converts a key-value tuple to an iterator of `Felt`s
- fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator<Item = Felt> {
+ pub(crate) fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator<Item = Felt> {
    let key_elements = key.into_iter();
    let value_elements = value.into_iter();
@@ -358,7 +360,7 @@ fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator<Item = Felt>
/// Compares two keys element-by-element using their integer representations, starting with the
/// most significant element.
- fn cmp_keys(key_1: RpoDigest, key_2: RpoDigest) -> Ordering {
+ pub(crate) fn cmp_keys(key_1: RpoDigest, key_2: RpoDigest) -> Ordering {
    for (v1, v2) in key_1.iter().zip(key_2.iter()).rev() {
        let v1 = v1.as_int();
        let v2 = v2.as_int();
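The most-significant-first ordering means keys are compared from `key[3]` down to `key[0]`. A small illustration (values arbitrary):

    // key[3] ties at 5, so the comparison falls through to lower elements
    // and is finally decided by key[0]: 9 > 1.
    let k1 = RpoDigest::from([Felt::new(9), ONE, ONE, Felt::new(5)]);
    let k2 = RpoDigest::from([Felt::new(1), ONE, ONE, Felt::new(5)]);
    assert_eq!(cmp_keys(k1, k2), core::cmp::Ordering::Greater);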

View File

@@ -1,13 +1,14 @@
- use super::{
-     EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, MerklePath,
-     NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
- };
use alloc::{
    collections::{BTreeMap, BTreeSet},
    string::ToString,
    vec::Vec,
};

+ use super::{
+     EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, MerklePath,
+     MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
+ };

mod error;
pub use error::{SmtLeafError, SmtProofError};
@@ -32,8 +33,8 @@ pub const SMT_DEPTH: u8 = 64;
/// Sparse Merkle tree mapping 256-bit keys to 256-bit values. Both keys and values are represented
/// by 4 field elements.
///
- /// All leaves sit at depth 64. The most significant element of the key is used to identify the leaf to
- /// which the key maps.
+ /// All leaves sit at depth 64. The most significant element of the key is used to identify the leaf
+ /// to which the key maps.
///
/// A leaf is either empty, or holds one or more key-value pairs. An empty leaf hashes to the empty
/// word. Otherwise, a leaf hashes to the hash of its key-value pairs, ordered by key first, value
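The leaf hashing rule stated in the doc comment above matches the `build_multiple_leaf_node` helper used later in the tests. A compact sketch in terms of the `cmp_keys` and `kv_to_elements` helpers this patch makes pub(crate):

    // Hash a non-empty leaf: order the (key, value) pairs by key, flatten
    // them into field elements, then hash the elements.
    fn leaf_hash(mut pairs: Vec<(RpoDigest, Word)>) -> RpoDigest {
        pairs.sort_by(|(k1, _), (k2, _)| cmp_keys(*k1, *k2));
        let elements: Vec<Felt> = pairs.into_iter().flat_map(kv_to_elements).collect();
        Rpo256::hash_elements(&elements)
    }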
@@ -113,6 +114,11 @@ impl Smt {
        <Self as SparseMerkleTree<SMT_DEPTH>>::root(self)
    }
/// Returns the number of non-empty leaves in this tree.
pub fn num_leaves(&self) -> usize {
self.leaves.len()
}
    /// Returns the leaf to which `key` maps
    pub fn get_leaf(&self, key: &RpoDigest) -> SmtLeaf {
        <Self as SparseMerkleTree<SMT_DEPTH>>::get_leaf(self, key)
@@ -120,12 +126,7 @@ impl Smt {
    /// Returns the value associated with `key`
    pub fn get_value(&self, key: &RpoDigest) -> Word {
-       let leaf_pos = LeafIndex::<SMT_DEPTH>::from(*key).value();
-       match self.leaves.get(&leaf_pos) {
-           Some(leaf) => leaf.get_value(key).unwrap_or_default(),
-           None => EMPTY_WORD,
-       }
+       <Self as SparseMerkleTree<SMT_DEPTH>>::get_value(self, key)
    }
    /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
@@ -134,6 +135,12 @@ impl Smt {
        <Self as SparseMerkleTree<SMT_DEPTH>>::open(self, key)
    }
/// Returns a boolean value indicating whether the SMT is empty.
pub fn is_empty(&self) -> bool {
debug_assert_eq!(self.leaves.is_empty(), self.root == Self::EMPTY_ROOT);
self.root == Self::EMPTY_ROOT
}
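The two inspection methods added above compose naturally; a trivial sketch:

    // A freshly constructed tree is empty and holds no leaves.
    let smt = Smt::new();
    assert!(smt.is_empty());
    assert_eq!(smt.num_leaves(), 0);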
    // ITERATORS
    // --------------------------------------------------------------------------------------------
@@ -171,6 +178,64 @@ impl Smt {
        <Self as SparseMerkleTree<SMT_DEPTH>>::insert(self, key, value)
    }
/// Computes what changes are necessary to insert the specified key-value pairs into this Merkle
/// tree, allowing for validation before applying those changes.
///
/// This method returns a [`MutationSet`], which contains all the information for inserting
/// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can
/// be queried with [`MutationSet::root()`]. Once a mutation set is returned,
/// [`Smt::apply_mutations()`] can be called in order to commit these changes to the Merkle
/// tree, or [`drop()`] to discard them.
///
/// # Example
/// ```
/// # use miden_crypto::{hash::rpo::RpoDigest, Felt, Word};
/// # use miden_crypto::merkle::{Smt, EmptySubtreeRoots, SMT_DEPTH};
/// let mut smt = Smt::new();
/// let pair = (RpoDigest::default(), Word::default());
/// let mutations = smt.compute_mutations(vec![pair]);
/// assert_eq!(mutations.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0));
/// smt.apply_mutations(mutations);
/// assert_eq!(smt.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0));
/// ```
pub fn compute_mutations(
&self,
kv_pairs: impl IntoIterator<Item = (RpoDigest, Word)>,
) -> MutationSet<SMT_DEPTH, RpoDigest, Word> {
<Self as SparseMerkleTree<SMT_DEPTH>>::compute_mutations(self, kv_pairs)
}
/// Applies the prospective mutations computed with [`Smt::compute_mutations()`] to this tree.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
    /// [`MerkleError::ConflictingRoots`]; `expected_root` is set to the current root of this
    /// tree, and `actual_root` to the root hash the `mutations` were computed against.
pub fn apply_mutations(
&mut self,
mutations: MutationSet<SMT_DEPTH, RpoDigest, Word>,
) -> Result<(), MerkleError> {
<Self as SparseMerkleTree<SMT_DEPTH>>::apply_mutations(self, mutations)
}
/// Applies the prospective mutations computed with [`Smt::compute_mutations()`] to this tree
/// and returns the reverse mutation set.
///
/// Applying the reverse mutation sets to the updated tree will revert the changes.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
    /// [`MerkleError::ConflictingRoots`]; `expected_root` is set to the current root of this
    /// tree, and `actual_root` to the root hash the `mutations` were computed against.
pub fn apply_mutations_with_reversion(
&mut self,
mutations: MutationSet<SMT_DEPTH, RpoDigest, Word>,
) -> Result<MutationSet<SMT_DEPTH, RpoDigest, Word>, MerkleError> {
<Self as SparseMerkleTree<SMT_DEPTH>>::apply_mutations_with_reversion(self, mutations)
}
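A minimal sketch of the compute → apply → revert cycle this method enables (key and value are illustrative):

    let mut smt = Smt::new();
    let key = RpoDigest::from([ONE, ONE, ONE, ONE]);
    let value = [ONE; 4];

    // Stage the change, apply it, and keep the reverse set.
    let mutations = smt.compute_mutations(vec![(key, value)]);
    let revert = smt.apply_mutations_with_reversion(mutations).unwrap();
    assert_eq!(smt.get_value(&key), value);

    // Applying the reverse set restores the previous (empty) state.
    smt.apply_mutations(revert).unwrap();
    assert!(smt.is_empty());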
    // HELPERS
    // --------------------------------------------------------------------------------------------
@@ -187,7 +252,7 @@ impl Smt {
                self.leaves.insert(leaf_index.value(), SmtLeaf::Single((key, value)));
                None
-           }
+           },
        }
    }
@@ -215,6 +280,7 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
    type Opening = SmtProof;

    const EMPTY_VALUE: Self::Value = EMPTY_WORD;
const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
    fn root(&self) -> RpoDigest {
        self.root
@@ -225,19 +291,18 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
    }

    fn get_inner_node(&self, index: NodeIndex) -> InnerNode {
-       self.inner_nodes.get(&index).cloned().unwrap_or_else(|| {
-           let node = EmptySubtreeRoots::entry(SMT_DEPTH, index.depth() + 1);
-
-           InnerNode { left: *node, right: *node }
-       })
+       self.inner_nodes
+           .get(&index)
+           .cloned()
+           .unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(SMT_DEPTH, index.depth()))
    }

-   fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) {
-       self.inner_nodes.insert(index, inner_node);
+   fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode> {
+       self.inner_nodes.insert(index, inner_node)
    }

-   fn remove_inner_node(&mut self, index: NodeIndex) {
-       let _ = self.inner_nodes.remove(&index);
+   fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode> {
+       self.inner_nodes.remove(&index)
    }

    fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value> {
@@ -249,6 +314,15 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
        }
    }
fn get_value(&self, key: &Self::Key) -> Self::Value {
let leaf_pos = LeafIndex::<SMT_DEPTH>::from(*key).value();
match self.leaves.get(&leaf_pos) {
Some(leaf) => leaf.get_value(key).unwrap_or_default(),
None => EMPTY_WORD,
}
}
    fn get_leaf(&self, key: &RpoDigest) -> Self::Leaf {
        let leaf_pos = LeafIndex::<SMT_DEPTH>::from(*key).value();
@@ -262,6 +336,28 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
        leaf.hash()
    }
fn construct_prospective_leaf(
&self,
mut existing_leaf: SmtLeaf,
key: &RpoDigest,
value: &Word,
) -> SmtLeaf {
debug_assert_eq!(existing_leaf.index(), Self::key_to_leaf_index(key));
match existing_leaf {
SmtLeaf::Empty(_) => SmtLeaf::new_single(*key, *value),
_ => {
if *value != EMPTY_WORD {
existing_leaf.insert(*key, *value);
} else {
existing_leaf.remove(*key);
}
existing_leaf
},
}
}
    fn key_to_leaf_index(key: &RpoDigest) -> LeafIndex<SMT_DEPTH> {
        let most_significant_felt = key[3];
        LeafIndex::new_max_depth(most_significant_felt.as_int())
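A one-line illustration of the mapping above (raw value arbitrary):

    // The leaf index is the integer value of the key's most significant
    // element, key[3]; `LeafIndex::from` applies the same rule.
    let key = RpoDigest::from([ONE, ONE, ONE, Felt::new(42)]);
    assert_eq!(LeafIndex::<SMT_DEPTH>::from(key).value(), 42);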
@@ -314,6 +410,14 @@ impl Serializable for Smt {
            target.write(value);
        }
    }
fn get_size_hint(&self) -> usize {
let entries_count = self.entries().count();
// Each entry is the size of a digest plus a word.
entries_count.get_size_hint()
+ entries_count * (RpoDigest::SERIALIZED_SIZE + EMPTY_WORD.get_size_hint())
}
}

impl Deserializable for Smt {
@@ -339,6 +443,7 @@ fn test_smt_serialization_deserialization() {
    let smt_default = Smt::default();
    let bytes = smt_default.to_bytes();
    assert_eq!(smt_default, Smt::read_from_bytes(&bytes).unwrap());
assert_eq!(bytes.len(), smt_default.get_size_hint());
    // Smt with values
    let smt_leaves_2: [(RpoDigest, Word); 2] = [
@@ -355,4 +460,5 @@ fn test_smt_serialization_deserialization() {
    let bytes = smt.to_bytes();
    assert_eq!(smt, Smt::read_from_bytes(&bytes).unwrap());
assert_eq!(bytes.len(), smt.get_size_hint());
}

View File

@@ -1,6 +1,7 @@
+ use alloc::string::ToString;
+
use super::{MerklePath, RpoDigest, SmtLeaf, SmtProofError, Word, SMT_DEPTH};
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
- use alloc::string::ToString;
/// A proof which can be used to assert membership (or non-membership) of key-value pairs in a
/// [`super::Smt`].
@@ -24,7 +25,7 @@ impl SmtProof {
    pub fn new(path: MerklePath, leaf: SmtLeaf) -> Result<Self, SmtProofError> {
        let depth: usize = SMT_DEPTH.into();
        if path.len() != depth {
-           return Err(SmtProofError::InvalidPathLength(path.len()));
+           return Err(SmtProofError::InvalidMerklePathLength(path.len()));
        }

        Ok(Self { path, leaf })
@@ -57,7 +58,7 @@ impl SmtProof {
                // make sure the Merkle path resolves to the correct root
                self.compute_root() == *root
-           }
+           },
            // If the key maps to a different leaf, the proof cannot verify membership of `value`
            None => false,
        }

View File

@@ -1,11 +1,14 @@
+ use alloc::{collections::BTreeMap, vec::Vec};
+
use super::{Felt, LeafIndex, NodeIndex, Rpo256, RpoDigest, Smt, SmtLeaf, EMPTY_WORD, SMT_DEPTH};
use crate::{
-   merkle::{EmptySubtreeRoots, MerkleStore},
+   merkle::{
+       smt::{NodeMutation, SparseMerkleTree},
+       EmptySubtreeRoots, MerkleStore, MutationSet,
+   },
    utils::{Deserializable, Serializable},
    Word, ONE, WORD_SIZE,
};
- use alloc::vec::Vec;
// SMT
// --------------------------------------------------------------------------------------------
@@ -257,6 +260,297 @@ fn test_smt_removal() {
    }
}
/// This tests that we can correctly calculate prospective leaves -- that is, we can construct
/// correct [`SmtLeaf`] values for a theoretical insertion on a Merkle tree without mutating or
/// cloning the tree.
#[test]
fn test_prospective_hash() {
let mut smt = Smt::default();
let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
let key_2: RpoDigest =
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(raw)]);
// Sort key_3 before key_1, to test non-append insertion.
let key_3: RpoDigest =
RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(raw)]);
let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE];
let value_3: [Felt; 4] = [3_u32.into(); WORD_SIZE];
// insert key-value 1
{
let prospective =
smt.construct_prospective_leaf(smt.get_leaf(&key_1), &key_1, &value_1).hash();
smt.insert(key_1, value_1);
let leaf = smt.get_leaf(&key_1);
assert_eq!(
prospective,
leaf.hash(),
"prospective hash for leaf {leaf:?} did not match actual hash",
);
}
// insert key-value 2
{
let prospective =
smt.construct_prospective_leaf(smt.get_leaf(&key_2), &key_2, &value_2).hash();
smt.insert(key_2, value_2);
let leaf = smt.get_leaf(&key_2);
assert_eq!(
prospective,
leaf.hash(),
"prospective hash for leaf {leaf:?} did not match actual hash",
);
}
// insert key-value 3
{
let prospective =
smt.construct_prospective_leaf(smt.get_leaf(&key_3), &key_3, &value_3).hash();
smt.insert(key_3, value_3);
let leaf = smt.get_leaf(&key_3);
assert_eq!(
prospective,
leaf.hash(),
"prospective hash for leaf {leaf:?} did not match actual hash",
);
}
// remove key 3
{
let old_leaf = smt.get_leaf(&key_3);
let old_value_3 = smt.insert(key_3, EMPTY_WORD);
assert_eq!(old_value_3, value_3);
let prospective_leaf =
smt.construct_prospective_leaf(smt.get_leaf(&key_3), &key_3, &old_value_3);
assert_eq!(
old_leaf.hash(),
prospective_leaf.hash(),
"removing and prospectively re-adding a leaf didn't yield the original leaf:\
\n original leaf: {old_leaf:?}\
\n prospective leaf: {prospective_leaf:?}",
);
}
// remove key 2
{
let old_leaf = smt.get_leaf(&key_2);
let old_value_2 = smt.insert(key_2, EMPTY_WORD);
assert_eq!(old_value_2, value_2);
let prospective_leaf =
smt.construct_prospective_leaf(smt.get_leaf(&key_2), &key_2, &old_value_2);
assert_eq!(
old_leaf.hash(),
prospective_leaf.hash(),
"removing and prospectively re-adding a leaf didn't yield the original leaf:\
\n original leaf: {old_leaf:?}\
\n prospective leaf: {prospective_leaf:?}",
);
}
// remove key 1
{
let old_leaf = smt.get_leaf(&key_1);
let old_value_1 = smt.insert(key_1, EMPTY_WORD);
assert_eq!(old_value_1, value_1);
let prospective_leaf =
smt.construct_prospective_leaf(smt.get_leaf(&key_1), &key_1, &old_value_1);
assert_eq!(
old_leaf.hash(),
prospective_leaf.hash(),
"removing and prospectively re-adding a leaf didn't yield the original leaf:\
\n original leaf: {old_leaf:?}\
\n prospective leaf: {prospective_leaf:?}",
);
}
}
/// This tests that we can perform prospective changes correctly.
#[test]
fn test_prospective_insertion() {
let mut smt = Smt::default();
let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
let key_2: RpoDigest =
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(raw)]);
// Sort key_3 before key_1, to test non-append insertion.
let key_3: RpoDigest =
RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(raw)]);
let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE];
let value_3: [Felt; 4] = [3_u32.into(); WORD_SIZE];
let root_empty = smt.root();
let root_1 = {
smt.insert(key_1, value_1);
smt.root()
};
let root_2 = {
smt.insert(key_2, value_2);
smt.root()
};
let root_3 = {
smt.insert(key_3, value_3);
smt.root()
};
// Test incremental updates.
let mut smt = Smt::default();
let mutations = smt.compute_mutations(vec![(key_1, value_1)]);
assert_eq!(mutations.root(), root_1, "prospective root 1 did not match actual root 1");
let revert = apply_mutations(&mut smt, mutations);
assert_eq!(smt.root(), root_1, "mutations before and after apply did not match");
assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
assert_eq!(revert.root(), root_empty, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
BTreeMap::from_iter([(key_1, EMPTY_WORD)]),
"reverse mutations pairs did not match"
);
assert_eq!(
revert.node_mutations,
smt.inner_nodes.keys().map(|key| (*key, NodeMutation::Removal)).collect(),
"reverse mutations inner nodes did not match"
);
let mutations = smt.compute_mutations(vec![(key_2, value_2)]);
assert_eq!(mutations.root(), root_2, "prospective root 2 did not match actual root 2");
let mutations =
smt.compute_mutations(vec![(key_3, EMPTY_WORD), (key_2, value_2), (key_3, value_3)]);
assert_eq!(mutations.root(), root_3, "mutations before and after apply did not match");
let old_root = smt.root();
let revert = apply_mutations(&mut smt, mutations);
assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
BTreeMap::from_iter([(key_2, EMPTY_WORD), (key_3, EMPTY_WORD)]),
"reverse mutations pairs did not match"
);
// Edge case: multiple values at the same key, where a later pair restores the original value.
let mutations = smt.compute_mutations(vec![(key_3, EMPTY_WORD), (key_3, value_3)]);
assert_eq!(mutations.root(), root_3);
let old_root = smt.root();
let revert = apply_mutations(&mut smt, mutations);
assert_eq!(smt.root(), root_3);
assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
BTreeMap::from_iter([(key_3, value_3)]),
"reverse mutations pairs did not match"
);
// Test batch updates, and that the order doesn't matter.
let pairs =
vec![(key_3, value_2), (key_2, EMPTY_WORD), (key_1, EMPTY_WORD), (key_3, EMPTY_WORD)];
let mutations = smt.compute_mutations(pairs);
assert_eq!(
mutations.root(),
root_empty,
"prospective root for batch removal did not match actual root",
);
let old_root = smt.root();
let revert = apply_mutations(&mut smt, mutations);
assert_eq!(smt.root(), root_empty, "mutations before and after apply did not match");
assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
assert_eq!(
revert.new_pairs,
BTreeMap::from_iter([(key_1, value_1), (key_2, value_2), (key_3, value_3)]),
"reverse mutations pairs did not match"
);
let pairs = vec![(key_3, value_3), (key_1, value_1), (key_2, value_2)];
let mutations = smt.compute_mutations(pairs);
assert_eq!(mutations.root(), root_3);
smt.apply_mutations(mutations).unwrap();
assert_eq!(smt.root(), root_3);
}
#[test]
fn test_mutations_revert() {
let mut smt = Smt::default();
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(1)]);
let key_2: RpoDigest =
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(2)]);
let key_3: RpoDigest =
RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(3)]);
let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE];
let value_3 = [3_u32.into(); WORD_SIZE];
smt.insert(key_1, value_1);
smt.insert(key_2, value_2);
let mutations =
smt.compute_mutations(vec![(key_1, EMPTY_WORD), (key_2, value_1), (key_3, value_3)]);
let original = smt.clone();
let revert = smt.apply_mutations_with_reversion(mutations).unwrap();
assert_eq!(revert.old_root, smt.root(), "reverse mutations old root did not match");
assert_eq!(revert.root(), original.root(), "reverse mutations new root did not match");
smt.apply_mutations(revert).unwrap();
assert_eq!(smt, original, "SMT with applied revert mutations did not match original SMT");
}
#[test]
fn test_mutation_set_serialization() {
let mut smt = Smt::default();
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(1)]);
let key_2: RpoDigest =
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(2)]);
let key_3: RpoDigest =
RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(3)]);
let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE];
let value_3 = [3_u32.into(); WORD_SIZE];
smt.insert(key_1, value_1);
smt.insert(key_2, value_2);
let mutations =
smt.compute_mutations(vec![(key_1, EMPTY_WORD), (key_2, value_1), (key_3, value_3)]);
let serialized = mutations.to_bytes();
let deserialized =
MutationSet::<SMT_DEPTH, RpoDigest, Word>::read_from_bytes(&serialized).unwrap();
assert_eq!(deserialized, mutations, "deserialized mutations did not match original");
let revert = smt.apply_mutations_with_reversion(mutations).unwrap();
let serialized = revert.to_bytes();
let deserialized =
MutationSet::<SMT_DEPTH, RpoDigest, Word>::read_from_bytes(&serialized).unwrap();
assert_eq!(deserialized, revert, "deserialized mutations did not match original");
}
/// Tests that 2 key-value pairs stored in the same leaf have the same path
#[test]
fn test_smt_path_to_keys_in_same_leaf_are_equal() {
@@ -287,8 +581,7 @@ fn test_empty_leaf_hash() {
#[test]
fn test_smt_get_value() {
    let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, ONE]);
-   let key_2: RpoDigest =
-       RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), 2_u32.into()]);
+   let key_2: RpoDigest = RpoDigest::from([2_u32, 2_u32, 2_u32, 2_u32]);

    let value_1 = [ONE; WORD_SIZE];
    let value_2 = [2_u32.into(); WORD_SIZE];
@@ -302,8 +595,7 @@ fn test_smt_get_value() {
    assert_eq!(value_2, returned_value_2);

    // Check that a key with no inserted value returns the empty word
-   let key_no_value =
-       RpoDigest::from([42_u32.into(), 42_u32.into(), 42_u32.into(), 42_u32.into()]);
+   let key_no_value = RpoDigest::from([42_u32, 42_u32, 42_u32, 42_u32]);
    assert_eq!(EMPTY_WORD, smt.get_value(&key_no_value));
}
@@ -312,8 +604,7 @@ fn test_smt_get_value() {
#[test]
fn test_smt_entries() {
    let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, ONE]);
-   let key_2: RpoDigest =
-       RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), 2_u32.into()]);
+   let key_2: RpoDigest = RpoDigest::from([2_u32, 2_u32, 2_u32, 2_u32]);

    let value_1 = [ONE; WORD_SIZE];
    let value_2 = [2_u32.into(); WORD_SIZE];
@@ -329,6 +620,16 @@ fn test_smt_entries() {
    assert!(entries.next().is_none());
}
/// Tests that `EMPTY_ROOT` constant generated in the `Smt` equals to the root of the empty tree of
/// depth 64
#[test]
fn test_smt_check_empty_root_constant() {
// get the root of the empty tree of depth 64
let empty_root_64_depth = EmptySubtreeRoots::empty_hashes(64)[0];
assert_eq!(empty_root_64_depth, Smt::EMPTY_ROOT);
}
// SMT LEAF
// --------------------------------------------------------------------------------------------
@@ -347,7 +648,7 @@ fn test_empty_smt_leaf_serialization() {
#[test]
fn test_single_smt_leaf_serialization() {
    let single_leaf = SmtLeaf::new_single(
-       RpoDigest::from([10_u32.into(), 11_u32.into(), 12_u32.into(), 13_u32.into()]),
+       RpoDigest::from([10_u32, 11_u32, 12_u32, 13_u32]),
        [1_u32.into(), 2_u32.into(), 3_u32.into(), 4_u32.into()],
    );
@@ -363,11 +664,11 @@ fn test_single_smt_leaf_serialization() {
fn test_multiple_smt_leaf_serialization_success() {
    let multiple_leaf = SmtLeaf::new_multiple(vec![
        (
-           RpoDigest::from([10_u32.into(), 11_u32.into(), 12_u32.into(), 13_u32.into()]),
+           RpoDigest::from([10_u32, 11_u32, 12_u32, 13_u32]),
            [1_u32.into(), 2_u32.into(), 3_u32.into(), 4_u32.into()],
        ),
        (
-           RpoDigest::from([100_u32.into(), 101_u32.into(), 102_u32.into(), 13_u32.into()]),
+           RpoDigest::from([100_u32, 101_u32, 102_u32, 13_u32]),
            [11_u32.into(), 12_u32.into(), 13_u32.into(), 14_u32.into()],
        ),
    ])
@@ -405,3 +706,19 @@ fn build_multiple_leaf_node(kv_pairs: &[(RpoDigest, Word)]) -> RpoDigest {
    Rpo256::hash_elements(&elements)
}
/// Applies mutations with and without reversion to the given SMT, compares the resulting SMTs,
/// and returns the mutation set for reversion.
fn apply_mutations(
smt: &mut Smt,
mutation_set: MutationSet<SMT_DEPTH, RpoDigest, Word>,
) -> MutationSet<SMT_DEPTH, RpoDigest, Word> {
let mut smt2 = smt.clone();
let reversion = smt.apply_mutations_with_reversion(mutation_set.clone()).unwrap();
smt2.apply_mutations(mutation_set).unwrap();
assert_eq!(&smt2, smt);
reversion
}

View File

@@ -1,9 +1,12 @@
+ use alloc::{collections::BTreeMap, vec::Vec};
+
+ use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
+
use super::{EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex};
use crate::{
    hash::rpo::{Rpo256, RpoDigest},
    Felt, Word, EMPTY_WORD,
};
- use alloc::vec::Vec;
mod full;
pub use full::{Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH};
@@ -39,22 +42,25 @@ pub const SMT_MAX_DEPTH: u8 = 64;
/// Every key maps to one leaf. If there are as many keys as there are leaves, then
/// [Self::Leaf] should be the same type as [Self::Value], as is the case with
/// [crate::merkle::SimpleSmt]. However, if there are more keys than leaves, then [`Self::Leaf`]
- /// must accomodate all keys that map to the same leaf.
+ /// must accommodate all keys that map to the same leaf.
///
/// [SparseMerkleTree] currently doesn't support optimizations that compress Merkle proofs.
pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
    /// The type for a key
-   type Key: Clone;
+   type Key: Clone + Ord;
    /// The type for a value
    type Value: Clone + PartialEq;
    /// The type for a leaf
-   type Leaf;
+   type Leaf: Clone;
    /// The type for an opening (i.e. a "proof") of a leaf
    type Opening;

    /// The default value used to compute the hash of empty leaves
    const EMPTY_VALUE: Self::Value;
/// The root of the empty tree with provided DEPTH
const EMPTY_ROOT: RpoDigest;
    // PROVIDED METHODS
    // ---------------------------------------------------------------------------------------------
@@ -129,9 +135,9 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
            node_hash = Rpo256::merge(&[left, right]);

            if node_hash == *EmptySubtreeRoots::entry(DEPTH, node_depth) {
-               // If a subtree is empty, when can remove the inner node, since it's equal to the
+               // If a subtree is empty, then can remove the inner node, since it's equal to the
                // default value
-               self.remove_inner_node(index)
+               self.remove_inner_node(index);
            } else {
                self.insert_inner_node(index, InnerNode { left, right });
            }
@@ -139,6 +145,226 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
        self.set_root(node_hash);
    }
/// Computes what changes are necessary to insert the specified key-value pairs into this Merkle
/// tree, allowing for validation before applying those changes.
///
/// This method returns a [`MutationSet`], which contains all the information for inserting
/// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can
/// be queried with [`MutationSet::root()`]. Once a mutation set is returned,
/// [`SparseMerkleTree::apply_mutations()`] can be called in order to commit these changes to
/// the Merkle tree, or [`drop()`] to discard them.
fn compute_mutations(
&self,
kv_pairs: impl IntoIterator<Item = (Self::Key, Self::Value)>,
) -> MutationSet<DEPTH, Self::Key, Self::Value> {
use NodeMutation::*;
let mut new_root = self.root();
let mut new_pairs: BTreeMap<Self::Key, Self::Value> = Default::default();
let mut node_mutations: BTreeMap<NodeIndex, NodeMutation> = Default::default();
for (key, value) in kv_pairs {
// If the old value and the new value are the same, there is nothing to update.
// For the unusual case that kv_pairs has multiple values at the same key, we'll have
// to check the key-value pairs we've already seen to get the "effective" old value.
let old_value = new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key));
if value == old_value {
continue;
}
let leaf_index = Self::key_to_leaf_index(&key);
let mut node_index = NodeIndex::from(leaf_index);
// We need the current leaf's hash to calculate the new leaf, but in the rare case that
// `kv_pairs` has multiple pairs that go into the same leaf, then those pairs are also
// part of the "current leaf".
let old_leaf = {
let pairs_at_index = new_pairs
.iter()
.filter(|&(new_key, _)| Self::key_to_leaf_index(new_key) == leaf_index);
pairs_at_index.fold(self.get_leaf(&key), |acc, (k, v)| {
// Most of the time `pairs_at_index` should only contain a single entry (or
// none at all), as multi-leaves should be really rare.
let existing_leaf = acc.clone();
self.construct_prospective_leaf(existing_leaf, k, v)
})
};
let new_leaf = self.construct_prospective_leaf(old_leaf, &key, &value);
let mut new_child_hash = Self::hash_leaf(&new_leaf);
for node_depth in (0..node_index.depth()).rev() {
// Whether the node we're replacing is the right child or the left child.
let is_right = node_index.is_value_odd();
node_index.move_up();
let old_node = node_mutations
.get(&node_index)
.map(|mutation| match mutation {
Addition(node) => node.clone(),
Removal => EmptySubtreeRoots::get_inner_node(DEPTH, node_depth),
})
.unwrap_or_else(|| self.get_inner_node(node_index));
let new_node = if is_right {
InnerNode {
left: old_node.left,
right: new_child_hash,
}
} else {
InnerNode {
left: new_child_hash,
right: old_node.right,
}
};
// The next iteration will operate on this new node's hash.
new_child_hash = new_node.hash();
let &equivalent_empty_hash = EmptySubtreeRoots::entry(DEPTH, node_depth);
let is_removal = new_child_hash == equivalent_empty_hash;
let new_entry = if is_removal { Removal } else { Addition(new_node) };
node_mutations.insert(node_index, new_entry);
}
// Once we're at depth 0, the last node we made is the new root.
new_root = new_child_hash;
// And then we're done with this pair; on to the next one.
new_pairs.insert(key, value);
}
MutationSet {
old_root: self.root(),
new_root,
node_mutations,
new_pairs,
}
}
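Taken together with `apply_mutations()` below, this enables a stage-then-commit flow. The following is a minimal sketch of how that flow is meant to be used from outside the crate, assuming the key-value `Smt` wrapper publicly re-exposes `compute_mutations()`/`apply_mutations()` (the key, value, and import names are assumptions based on the crate's public API, not part of this diff):

```
use miden_crypto::hash::rpo::RpoDigest;
use miden_crypto::merkle::Smt;
use miden_crypto::{Word, ONE, ZERO};

fn main() {
    let mut smt = Smt::new();
    let key = RpoDigest::new([ZERO, ZERO, ZERO, ONE]);
    let value: Word = [ONE, ONE, ONE, ONE];

    // Stage the insertion without touching the tree.
    let mutations = smt.compute_mutations(vec![(key, value)]);
    assert_eq!(mutations.old_root(), smt.root());

    // Any validation of `mutations.root()` can happen here; afterwards,
    // either commit the staged changes or simply drop them.
    smt.apply_mutations(mutations).expect("old root matches, so this cannot fail");
    assert_eq!(smt.get_value(&key), value);
}
```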
/// Applies the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`] to
/// this tree.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
/// [`MerkleError::ConflictingRoots`], which reports both the expected root (the current root
/// of this tree) and the actual root, i.e. the root the `mutations` were computed against.
fn apply_mutations(
&mut self,
mutations: MutationSet<DEPTH, Self::Key, Self::Value>,
) -> Result<(), MerkleError>
where
Self: Sized,
{
use NodeMutation::*;
let MutationSet {
old_root,
node_mutations,
new_pairs,
new_root,
} = mutations;
// Guard against accidentally trying to apply mutations that were computed against a
// different tree, including a stale version of this tree.
if old_root != self.root() {
return Err(MerkleError::ConflictingRoots {
expected_root: self.root(),
actual_root: old_root,
});
}
for (index, mutation) in node_mutations {
match mutation {
Removal => {
self.remove_inner_node(index);
},
Addition(node) => {
self.insert_inner_node(index, node);
},
}
}
for (key, value) in new_pairs {
self.insert_value(key, value);
}
self.set_root(new_root);
Ok(())
}
/// Applies the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`] to
/// this tree and returns the reverse mutation set. Applying the reverse mutation sets to the
/// updated tree will revert the changes.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
/// [`MerkleError::ConflictingRoots`], which reports both the expected root (the current root
/// of this tree) and the actual root, i.e. the root the `mutations` were computed against.
fn apply_mutations_with_reversion(
&mut self,
mutations: MutationSet<DEPTH, Self::Key, Self::Value>,
) -> Result<MutationSet<DEPTH, Self::Key, Self::Value>, MerkleError>
where
Self: Sized,
{
use NodeMutation::*;
let MutationSet {
old_root,
node_mutations,
new_pairs,
new_root,
} = mutations;
// Guard against accidentally trying to apply mutations that were computed against a
// different tree, including a stale version of this tree.
if old_root != self.root() {
return Err(MerkleError::ConflictingRoots {
expected_root: self.root(),
actual_root: old_root,
});
}
let mut reverse_mutations = BTreeMap::new();
for (index, mutation) in node_mutations {
match mutation {
Removal => {
if let Some(node) = self.remove_inner_node(index) {
reverse_mutations.insert(index, Addition(node));
}
},
Addition(node) => {
if let Some(old_node) = self.insert_inner_node(index, node) {
reverse_mutations.insert(index, Addition(old_node));
} else {
reverse_mutations.insert(index, Removal);
}
},
}
}
let mut reverse_pairs = BTreeMap::new();
for (key, value) in new_pairs {
if let Some(old_value) = self.insert_value(key.clone(), value) {
reverse_pairs.insert(key, old_value);
} else {
reverse_pairs.insert(key, Self::EMPTY_VALUE);
}
}
self.set_root(new_root);
Ok(MutationSet {
old_root: new_root,
node_mutations: reverse_mutations,
new_pairs: reverse_pairs,
new_root: old_root,
})
}
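A sketch of the apply-then-revert round trip this enables, continuing the hypothetical example above (`SMT_DEPTH` is 64, and `MutationSet` is assumed to be re-exported from `miden_crypto::merkle`):

```
use miden_crypto::hash::rpo::RpoDigest;
use miden_crypto::merkle::{MerkleError, MutationSet, Smt, SMT_DEPTH};
use miden_crypto::Word;

fn apply_then_revert(
    smt: &mut Smt,
    mutations: MutationSet<SMT_DEPTH, RpoDigest, Word>,
) -> Result<(), MerkleError> {
    let original_root = smt.root();

    // Commit the staged changes and capture the inverse mutation set.
    let reversion = smt.apply_mutations_with_reversion(mutations)?;
    // The reverse set is anchored at the tree's new root...
    assert_eq!(reversion.old_root(), smt.root());
    // ...and its target is the root we started from.
    assert_eq!(reversion.root(), original_root);

    // Applying the inverse set restores the original root.
    smt.apply_mutations(reversion)?;
    assert_eq!(smt.root(), original_root);
    Ok(())
}
```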
// REQUIRED METHODS
// ---------------------------------------------------------------------------------------------
@@ -152,20 +378,42 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
fn get_inner_node(&self, index: NodeIndex) -> InnerNode;
/// Inserts an inner node at the given index
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode>;
/// Removes an inner node at the given index
fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode>;
/// Inserts a leaf node, and returns the previous value at the key, if one exists
fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value>;
/// Returns the value at the specified key. Recall that by definition, any key that hasn't been
/// updated is associated with [`Self::EMPTY_VALUE`].
fn get_value(&self, key: &Self::Key) -> Self::Value;
/// Returns the leaf at the specified index.
fn get_leaf(&self, key: &Self::Key) -> Self::Leaf;
/// Returns the hash of a leaf
fn hash_leaf(leaf: &Self::Leaf) -> RpoDigest;
/// Returns what a leaf would look like if a key-value pair were inserted into the tree, without
/// mutating the tree itself. The existing leaf can be empty.
///
/// To get a prospective leaf based on the current state of the tree, use `self.get_leaf(key)`
/// as the argument for `existing_leaf`. The return value from this function can be chained back
/// into this function as the first argument to continue making prospective changes.
///
/// # Invariants
/// Because this method is for a prospective key-value insertion into a specific leaf,
/// `existing_leaf` must have the same leaf index as `key` (as determined by
/// [`SparseMerkleTree::key_to_leaf_index()`]), or the result will be meaningless.
fn construct_prospective_leaf(
&self,
existing_leaf: Self::Leaf,
key: &Self::Key,
value: &Self::Value,
) -> Self::Leaf;
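Because the return value chains back in as `existing_leaf`, several prospective insertions into the same leaf can be folded together without mutating the tree, which is exactly what `compute_mutations()` does above. An illustrative sketch (the trait is `pub(crate)`, so this only compiles inside the crate; the helper, its name, and the requirement that both keys map to one leaf index are hypothetical):

```
fn prospective_leaf_hash<const DEPTH: u8, T: SparseMerkleTree<DEPTH>>(
    tree: &T,
    key_a: &T::Key,
    value_a: &T::Value,
    key_b: &T::Key,
    value_b: &T::Value,
) -> RpoDigest {
    // Both keys must target the same leaf, per the invariant above.
    debug_assert_eq!(T::key_to_leaf_index(key_a), T::key_to_leaf_index(key_b));
    // Start from the leaf as it currently exists in the tree...
    let leaf = tree.get_leaf(key_a);
    // ...then thread it through successive prospective insertions.
    let leaf = tree.construct_prospective_leaf(leaf, key_a, value_a);
    let leaf = tree.construct_prospective_leaf(leaf, key_b, value_b);
    T::hash_leaf(&leaf)
}
```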
/// Maps a key to a leaf index
fn key_to_leaf_index(key: &Self::Key) -> LeafIndex<DEPTH>;
@@ -180,7 +428,7 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct InnerNode {
pub left: RpoDigest,
pub right: RpoDigest,
}
@@ -234,7 +482,7 @@ impl<const DEPTH: u8> TryFrom<NodeIndex> for LeafIndex<DEPTH> {
fn try_from(node_index: NodeIndex) -> Result<Self, Self::Error> {
if node_index.depth() != DEPTH {
return Err(MerkleError::InvalidNodeIndexDepth {
expected: DEPTH,
provided: node_index.depth(),
});
@@ -243,3 +491,147 @@ impl<const DEPTH: u8> TryFrom<NodeIndex> for LeafIndex<DEPTH> {
Self::new(node_index.value())
}
}
impl<const DEPTH: u8> Serializable for LeafIndex<DEPTH> {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
self.index.write_into(target);
}
}
impl<const DEPTH: u8> Deserializable for LeafIndex<DEPTH> {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
Ok(Self { index: source.read()? })
}
}
// MUTATIONS
// ================================================================================================
/// A change to an inner node of a sparse Merkle tree that hasn't yet been applied.
/// [`MutationSet`] stores this type in relation to a [`NodeIndex`] to keep track of what changes
/// need to occur at which node indices.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NodeMutation {
/// Node needs to be removed.
Removal,
/// Node needs to be inserted.
Addition(InnerNode),
}
/// Represents a group of prospective mutations to a `SparseMerkleTree`, created by
/// `SparseMerkleTree::compute_mutations()`, and that can be applied with
/// `SparseMerkleTree::apply_mutations()`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct MutationSet<const DEPTH: u8, K, V> {
/// The root of the Merkle tree this MutationSet is for, recorded at the time
/// [`SparseMerkleTree::compute_mutations()`] was called. Exists to guard against applying
/// mutations to the wrong tree or applying stale mutations to a tree that has since changed.
old_root: RpoDigest,
/// The set of nodes that need to be removed or added. The "effective" node at an index is the
/// Merkle tree's existing node at that index, with the [`NodeMutation`] in this map at that
/// index overlayed, if any. Each [`NodeMutation::Addition`] corresponds to a
/// [`SparseMerkleTree::insert_inner_node()`] call, and each [`NodeMutation::Removal`]
/// corresponds to a [`SparseMerkleTree::remove_inner_node()`] call.
node_mutations: BTreeMap<NodeIndex, NodeMutation>,
/// The set of top-level key-value pairs we're prospectively adding to the tree, including
/// adding empty values. The "effective" value for a key is the value in this BTreeMap, falling
/// back to the existing value in the Merkle tree. Each entry corresponds to a
/// [`SparseMerkleTree::insert_value()`] call.
new_pairs: BTreeMap<K, V>,
/// The calculated root for the Merkle tree, given these mutations. Publicly retrievable with
/// [`MutationSet::root()`]. Corresponds to a [`SparseMerkleTree::set_root()`] call.
new_root: RpoDigest,
}
impl<const DEPTH: u8, K, V> MutationSet<DEPTH, K, V> {
/// Returns the SMT root that was calculated during `SparseMerkleTree::compute_mutations()`. See
/// that method for more information.
pub fn root(&self) -> RpoDigest {
self.new_root
}
/// Returns the SMT root before the mutations were applied.
pub fn old_root(&self) -> RpoDigest {
self.old_root
}
/// Returns the set of inner nodes that need to be removed or added.
pub fn node_mutations(&self) -> &BTreeMap<NodeIndex, NodeMutation> {
&self.node_mutations
}
/// Returns the set of top-level key-value pairs that need to be added, updated or deleted
/// (i.e. set to `EMPTY_WORD`).
pub fn new_pairs(&self) -> &BTreeMap<K, V> {
&self.new_pairs
}
}
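Since the mutation set exposes its parts, a caller can audit what would change before committing anything. A hypothetical inspection pass on the consumer (std) side, assuming `MutationSet` and `NodeMutation` are publicly re-exported:

```
use miden_crypto::merkle::{MutationSet, NodeMutation};

fn describe<const DEPTH: u8, K, V>(mutations: &MutationSet<DEPTH, K, V>) {
    let mut added = 0usize;
    let mut removed = 0usize;
    for mutation in mutations.node_mutations().values() {
        match mutation {
            NodeMutation::Addition(_) => added += 1,
            NodeMutation::Removal => removed += 1,
        }
    }
    // `new_pairs` counts staged key-value writes, including writes of the
    // empty value (i.e. deletions at the key-value level).
    let pairs = mutations.new_pairs().len();
    println!("{added} node writes, {removed} node removals, {pairs} staged pairs");
}
```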
// SERIALIZATION
// ================================================================================================
impl Serializable for InnerNode {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
self.left.write_into(target);
self.right.write_into(target);
}
}
impl Deserializable for InnerNode {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let left = source.read()?;
let right = source.read()?;
Ok(Self { left, right })
}
}
impl Serializable for NodeMutation {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
match self {
NodeMutation::Removal => target.write_bool(false),
NodeMutation::Addition(inner_node) => {
target.write_bool(true);
inner_node.write_into(target);
},
}
}
}
impl Deserializable for NodeMutation {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
if source.read_bool()? {
let inner_node = source.read()?;
return Ok(NodeMutation::Addition(inner_node));
}
Ok(NodeMutation::Removal)
}
}
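The wire format is thus a one-byte tag (`0` for `Removal`, `1` for `Addition`) followed, for additions, by the node's two digests. A quick round-trip sketch, assuming winter-utils encodes a bool as a single byte and that these types are publicly re-exported:

```
use miden_crypto::hash::rpo::RpoDigest;
use miden_crypto::merkle::{InnerNode, NodeMutation};
use miden_crypto::utils::{Deserializable, Serializable};

fn main() {
    // A removal is just the tag byte.
    assert_eq!(NodeMutation::Removal.to_bytes(), vec![0u8]);

    // An addition is the tag byte followed by the left and right digests.
    let node = InnerNode { left: RpoDigest::default(), right: RpoDigest::default() };
    let addition = NodeMutation::Addition(node);
    let bytes = addition.to_bytes();
    assert_eq!(bytes[0], 1u8);
    assert_eq!(NodeMutation::read_from_bytes(&bytes).unwrap(), addition);
}
```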
impl<const DEPTH: u8, K: Serializable, V: Serializable> Serializable for MutationSet<DEPTH, K, V> {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write(self.old_root);
target.write(self.new_root);
self.node_mutations.write_into(target);
self.new_pairs.write_into(target);
}
}
impl<const DEPTH: u8, K: Deserializable + Ord, V: Deserializable> Deserializable
for MutationSet<DEPTH, K, V>
{
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let old_root = source.read()?;
let new_root = source.read()?;
let node_mutations = source.read()?;
let new_pairs = source.read()?;
Ok(Self {
old_root,
node_mutations,
new_pairs,
new_root,
})
}
}
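The full mutation set serializes as `old_root`, `new_root`, the node-mutation map, and the staged key-value pairs, in that order. A hypothetical round trip over the key-value `Smt`'s type parameters (imports and the `compute_mutations` call are assumptions based on the crate's public API):

```
use miden_crypto::hash::rpo::RpoDigest;
use miden_crypto::merkle::{MutationSet, Smt, SMT_DEPTH};
use miden_crypto::utils::{Deserializable, Serializable};
use miden_crypto::{Word, ONE, ZERO};

fn main() {
    let smt = Smt::new();
    let key = RpoDigest::new([ZERO, ZERO, ZERO, ONE]);
    let value: Word = [ONE; 4];
    let mutations = smt.compute_mutations(vec![(key, value)]);

    // Serialize, deserialize, and compare.
    let bytes = mutations.to_bytes();
    let decoded = MutationSet::<SMT_DEPTH, RpoDigest, Word>::read_from_bytes(&bytes).unwrap();
    assert_eq!(decoded, mutations);
}
```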


@@ -1,9 +1,10 @@
use alloc::collections::{BTreeMap, BTreeSet};
use super::{
super::ValuePath, EmptySubtreeRoots, InnerNode, InnerNodeInfo, LeafIndex, MerkleError,
MerklePath, MutationSet, NodeIndex, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
SMT_MAX_DEPTH, SMT_MIN_DEPTH,
};
#[cfg(test)]
mod tests;
@@ -80,7 +81,7 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
for (idx, (key, value)) in entries.into_iter().enumerate() {
if idx >= max_num_entries {
return Err(MerkleError::TooManyEntries(max_num_entries));
}
let old_value = tree.insert(LeafIndex::<DEPTH>::new(key)?, value);
@@ -157,6 +158,12 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
<Self as SparseMerkleTree<DEPTH>>::open(self, key)
}
/// Returns a boolean value indicating whether the SMT is empty.
pub fn is_empty(&self) -> bool {
debug_assert_eq!(self.leaves.is_empty(), self.root == Self::EMPTY_ROOT);
self.root == Self::EMPTY_ROOT
}
// ITERATORS
// --------------------------------------------------------------------------------------------
@@ -187,6 +194,65 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
<Self as SparseMerkleTree<DEPTH>>::insert(self, key, value)
}
/// Computes what changes are necessary to insert the specified key-value pairs into this
/// Merkle tree, allowing for validation before applying those changes.
///
/// This method returns a [`MutationSet`], which contains all the information for inserting
/// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can
/// be queried with [`MutationSet::root()`]. Once a mutation set is returned,
/// [`SimpleSmt::apply_mutations()`] can be called in order to commit these changes to the
/// Merkle tree, or [`drop()`] to discard them.
///
/// # Example
/// ```
/// # use miden_crypto::{hash::rpo::RpoDigest, Felt, Word};
/// # use miden_crypto::merkle::{LeafIndex, SimpleSmt, EmptySubtreeRoots, SMT_DEPTH};
/// let mut smt: SimpleSmt<3> = SimpleSmt::new().unwrap();
/// let pair = (LeafIndex::default(), Word::default());
/// let mutations = smt.compute_mutations(vec![pair]);
/// assert_eq!(mutations.root(), *EmptySubtreeRoots::entry(3, 0));
/// smt.apply_mutations(mutations);
/// assert_eq!(smt.root(), *EmptySubtreeRoots::entry(3, 0));
/// ```
pub fn compute_mutations(
&self,
kv_pairs: impl IntoIterator<Item = (LeafIndex<DEPTH>, Word)>,
) -> MutationSet<DEPTH, LeafIndex<DEPTH>, Word> {
<Self as SparseMerkleTree<DEPTH>>::compute_mutations(self, kv_pairs)
}
/// Applies the prospective mutations computed with [`SimpleSmt::compute_mutations()`] to this
/// tree.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
/// [`MerkleError::ConflictingRoots`], which reports both the expected root (the current
/// root of this tree) and the actual root, i.e. the root the `mutations` were computed
/// against.
pub fn apply_mutations(
&mut self,
mutations: MutationSet<DEPTH, LeafIndex<DEPTH>, Word>,
) -> Result<(), MerkleError> {
<Self as SparseMerkleTree<DEPTH>>::apply_mutations(self, mutations)
}
/// Applies the prospective mutations computed with [`SimpleSmt::compute_mutations()`] to
/// this tree and returns the reverse mutation set.
///
/// Applying the reverse mutation sets to the updated tree will revert the changes.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
/// [`MerkleError::ConflictingRoots`], which reports both the expected root (the current
/// root of this tree) and the actual root, i.e. the root the `mutations` were computed
/// against.
pub fn apply_mutations_with_reversion(
&mut self,
mutations: MutationSet<DEPTH, LeafIndex<DEPTH>, Word>,
) -> Result<MutationSet<DEPTH, LeafIndex<DEPTH>, Word>, MerkleError> {
<Self as SparseMerkleTree<DEPTH>>::apply_mutations_with_reversion(self, mutations)
}
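The same flow on `SimpleSmt`, keyed by `LeafIndex`, exercising `is_empty()` as well (a sketch; the depth of 8 and the key/value are arbitrary choices, not part of this diff):

```
use miden_crypto::merkle::{LeafIndex, SimpleSmt};
use miden_crypto::{Word, ONE};

fn main() {
    let mut smt: SimpleSmt<8> = SimpleSmt::new().unwrap();
    let key = LeafIndex::<8>::new(3).unwrap();
    let value: Word = [ONE; 4];

    // Stage, commit, and keep the inverse set.
    let mutations = smt.compute_mutations([(key, value)]);
    let reversion = smt.apply_mutations_with_reversion(mutations).unwrap();
    assert_eq!(smt.get_leaf(&key), value);

    // Undo: the tree is empty again, which `is_empty()` confirms cheaply
    // by comparing the root against `EMPTY_ROOT`.
    smt.apply_mutations(reversion).unwrap();
    assert!(smt.is_empty());
}
```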
/// Inserts a subtree at the specified index. The depth at which the subtree is inserted is
/// computed as `DEPTH - SUBTREE_DEPTH`.
///
@@ -197,7 +263,7 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
subtree: SimpleSmt<SUBTREE_DEPTH>,
) -> Result<RpoDigest, MerkleError> {
if SUBTREE_DEPTH > DEPTH {
return Err(MerkleError::SubtreeDepthExceedsDepth {
subtree_depth: SUBTREE_DEPTH,
tree_depth: DEPTH,
});
@@ -255,6 +321,7 @@ impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
type Opening = ValuePath;
const EMPTY_VALUE: Self::Value = EMPTY_WORD;
const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(DEPTH, 0);
fn root(&self) -> RpoDigest {
self.root
@@ -265,19 +332,18 @@ impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
}
fn get_inner_node(&self, index: NodeIndex) -> InnerNode {
self.inner_nodes
.get(&index)
.cloned()
.unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(DEPTH, index.depth()))
}
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode> {
self.inner_nodes.insert(index, inner_node)
}
fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode> {
self.inner_nodes.remove(&index)
}
fn insert_value(&mut self, key: LeafIndex<DEPTH>, value: Word) -> Option<Word> {
@@ -288,6 +354,10 @@ impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
}
}
fn get_value(&self, key: &LeafIndex<DEPTH>) -> Word {
self.get_leaf(key)
}
fn get_leaf(&self, key: &LeafIndex<DEPTH>) -> Word {
let leaf_pos = key.value();
match self.leaves.get(&leaf_pos) {
@@ -301,6 +371,15 @@ impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
leaf.into()
}
fn construct_prospective_leaf(
&self,
_existing_leaf: Word,
_key: &LeafIndex<DEPTH>,
value: &Word,
) -> Word {
*value
}
fn key_to_leaf_index(key: &LeafIndex<DEPTH>) -> LeafIndex<DEPTH> {
*key
}


@@ -1,3 +1,7 @@
use alloc::vec::Vec;
use assert_matches::assert_matches;
use super::{
super::{MerkleError, RpoDigest, SimpleSmt},
NodeIndex,
@@ -10,7 +14,6 @@ use crate::{
},
Word, EMPTY_WORD,
};
// TEST DATA
// ================================================================================================
@@ -256,12 +259,12 @@ fn test_simplesmt_fail_on_duplicates() {
// consecutive
let entries = [(1, *first), (1, *second)];
let smt = SimpleSmt::<64>::with_leaves(entries);
assert_matches!(smt.unwrap_err(), MerkleError::DuplicateValuesForIndex(1));
// not consecutive
let entries = [(1, *first), (5, int_to_leaf(5)), (1, *second)];
let smt = SimpleSmt::<64>::with_leaves(entries);
assert_matches!(smt.unwrap_err(), MerkleError::DuplicateValuesForIndex(1));
}
}
@@ -443,6 +446,23 @@ fn test_simplesmt_set_subtree_entire_tree() {
assert_eq!(tree.root(), *EmptySubtreeRoots::entry(DEPTH, 0));
}
/// Tests that the `EMPTY_ROOT` constant generated in the `SimpleSmt` equals the root of the
/// empty tree of the corresponding depth
#[test]
fn test_simplesmt_check_empty_root_constant() {
// get the root of the empty tree of depth 64
let empty_root_64_depth = EmptySubtreeRoots::empty_hashes(64)[0];
assert_eq!(empty_root_64_depth, SimpleSmt::<64>::EMPTY_ROOT);
// get the root of the empty tree of depth 32
let empty_root_32_depth = EmptySubtreeRoots::empty_hashes(32)[0];
assert_eq!(empty_root_32_depth, SimpleSmt::<32>::EMPTY_ROOT);
// get the root of the empty tree of depth 1
let empty_root_1_depth = EmptySubtreeRoots::empty_hashes(1)[0];
assert_eq!(empty_root_1_depth, SimpleSmt::<1>::EMPTY_ROOT);
}
// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------


@@ -127,8 +127,8 @@ impl<T: KvMap<RpoDigest, StoreNode>> MerkleStore<T> {
/// # Errors
/// This method can return the following errors:
/// - `RootNotInStore` if the `root` is not present in the store.
/// - `NodeIndexNotFoundInStore` if a node needed to traverse from `root` to `index` is not
/// present in the store.
pub fn get_node(&self, root: RpoDigest, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
let mut hash = root;
@@ -136,7 +136,10 @@ impl<T: KvMap<RpoDigest, StoreNode>> MerkleStore<T> {
self.nodes.get(&hash).ok_or(MerkleError::RootNotInStore(hash))?;
for i in (0..index.depth()).rev() {
let node = self
.nodes
.get(&hash)
.ok_or(MerkleError::NodeIndexNotFoundInStore(hash, index))?;
let bit = (index.value() >> i) & 1;
hash = if bit == 0 { node.left } else { node.right }
@@ -152,8 +155,8 @@ impl<T: KvMap<RpoDigest, StoreNode>> MerkleStore<T> {
/// # Errors
/// This method can return the following errors:
/// - `RootNotInStore` if the `root` is not present in the store.
/// - `NodeIndexNotFoundInStore` if a node needed to traverse from `root` to `index` is not
/// present in the store.
pub fn get_path(&self, root: RpoDigest, index: NodeIndex) -> Result<ValuePath, MerkleError> {
let mut hash = root;
let mut path = Vec::with_capacity(index.depth().into());
@@ -162,7 +165,10 @@ impl<T: KvMap<RpoDigest, StoreNode>> MerkleStore<T> {
self.nodes.get(&hash).ok_or(MerkleError::RootNotInStore(hash))?;
for i in (0..index.depth()).rev() {
let node = self
.nodes
.get(&hash)
.ok_or(MerkleError::NodeIndexNotFoundInStore(hash, index))?;
let bit = (index.value() >> i) & 1;
hash = if bit == 0 {
@@ -421,8 +427,8 @@ impl<T: KvMap<RpoDigest, StoreNode>> MerkleStore<T> {
/// # Errors
/// This method can return the following errors:
/// - `RootNotInStore` if the `root` is not present in the store.
/// - `NodeIndexNotFoundInStore` if a node needed to traverse from `root` to `index` is not
/// present in the store.
pub fn set_node(
&mut self,
mut root: RpoDigest,


@@ -1,4 +1,11 @@
use assert_matches::assert_matches;
use seq_macro::seq;
#[cfg(feature = "std")]
use {
super::{Deserializable, Serializable},
alloc::boxed::Box,
std::error::Error,
};
use super::{
DefaultMerkleStore as MerkleStore, EmptySubtreeRoots, MerkleError, MerklePath, NodeIndex,
@@ -11,13 +18,6 @@ use crate::{
Felt, Word, ONE, WORD_SIZE, ZERO,
};
// TEST DATA
// ================================================================================================
@@ -43,14 +43,14 @@ const VALUES8: [RpoDigest; 8] = [
fn test_root_not_in_store() -> Result<(), MerkleError> {
let mtree = MerkleTree::new(digests_to_words(&VALUES4))?;
let store = MerkleStore::from(&mtree);
assert_matches!(
store.get_node(VALUES4[0], NodeIndex::make(mtree.depth(), 0)),
Err(MerkleError::RootNotInStore(root)) if root == VALUES4[0],
"Leaf 0 is not a root"
);
assert_matches!(
store.get_path(VALUES4[0], NodeIndex::make(mtree.depth(), 0)),
Err(MerkleError::RootNotInStore(root)) if root == VALUES4[0],
"Leaf 0 is not a root"
);
@@ -65,46 +65,46 @@ fn test_merkle_tree() -> Result<(), MerkleError> {
// STORE LEAVES ARE CORRECT -------------------------------------------------------------------
// checks the leaves in the store correspond to the expected values
assert_eq!(
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 0)).unwrap(),
VALUES4[0],
"node 0 must be in the tree"
);
assert_eq!(
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 1)).unwrap(),
VALUES4[1],
"node 1 must be in the tree"
);
assert_eq!(
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 2)).unwrap(),
VALUES4[2],
"node 2 must be in the tree"
);
assert_eq!(
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 3)).unwrap(),
VALUES4[3],
"node 3 must be in the tree"
);
// STORE LEAVES MATCH TREE --------------------------------------------------------------------
// sanity check the values returned by the store and the tree
assert_eq!(
mtree.get_node(NodeIndex::make(mtree.depth(), 0)).unwrap(),
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 0)).unwrap(),
"node 0 must be the same for both MerkleTree and MerkleStore"
);
assert_eq!(
mtree.get_node(NodeIndex::make(mtree.depth(), 1)).unwrap(),
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 1)).unwrap(),
"node 1 must be the same for both MerkleTree and MerkleStore"
);
assert_eq!(
mtree.get_node(NodeIndex::make(mtree.depth(), 2)).unwrap(),
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 2)).unwrap(),
"node 2 must be the same for both MerkleTree and MerkleStore"
);
assert_eq!(
mtree.get_node(NodeIndex::make(mtree.depth(), 3)).unwrap(),
store.get_node(mtree.root(), NodeIndex::make(mtree.depth(), 3)).unwrap(),
"node 3 must be the same for both MerkleTree and MerkleStore"
);
@@ -116,8 +116,8 @@ fn test_merkle_tree() -> Result<(), MerkleError> {
"Value for merkle path at index 0 must match leaf value" "Value for merkle path at index 0 must match leaf value"
); );
assert_eq!( assert_eq!(
mtree.get_path(NodeIndex::make(mtree.depth(), 0)), mtree.get_path(NodeIndex::make(mtree.depth(), 0)).unwrap(),
Ok(result.path), result.path,
"merkle path for index 0 must be the same for the MerkleTree and MerkleStore" "merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
); );
@@ -127,8 +127,8 @@ fn test_merkle_tree() -> Result<(), MerkleError> {
"Value for merkle path at index 0 must match leaf value" "Value for merkle path at index 0 must match leaf value"
); );
assert_eq!( assert_eq!(
mtree.get_path(NodeIndex::make(mtree.depth(), 1)), mtree.get_path(NodeIndex::make(mtree.depth(), 1)).unwrap(),
Ok(result.path), result.path,
"merkle path for index 1 must be the same for the MerkleTree and MerkleStore" "merkle path for index 1 must be the same for the MerkleTree and MerkleStore"
); );
@@ -138,8 +138,8 @@ fn test_merkle_tree() -> Result<(), MerkleError> {
"Value for merkle path at index 0 must match leaf value" "Value for merkle path at index 0 must match leaf value"
); );
assert_eq!( assert_eq!(
mtree.get_path(NodeIndex::make(mtree.depth(), 2)), mtree.get_path(NodeIndex::make(mtree.depth(), 2)).unwrap(),
Ok(result.path), result.path,
"merkle path for index 0 must be the same for the MerkleTree and MerkleStore" "merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
); );
@@ -149,8 +149,8 @@ fn test_merkle_tree() -> Result<(), MerkleError> {
"Value for merkle path at index 0 must match leaf value" "Value for merkle path at index 0 must match leaf value"
); );
assert_eq!( assert_eq!(
mtree.get_path(NodeIndex::make(mtree.depth(), 3)), mtree.get_path(NodeIndex::make(mtree.depth(), 3)).unwrap(),
Ok(result.path), result.path,
"merkle path for index 0 must be the same for the MerkleTree and MerkleStore" "merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
); );
@@ -241,56 +241,56 @@ fn test_sparse_merkle_tree() -> Result<(), MerkleError> {
// STORE LEAVES ARE CORRECT ==============================================================
// checks the leaves in the store correspond to the expected values
assert_eq!(
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 0)).unwrap(),
VALUES4[0],
"node 0 must be in the tree"
);
assert_eq!(
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 1)).unwrap(),
VALUES4[1],
"node 1 must be in the tree"
);
assert_eq!(
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 2)).unwrap(),
VALUES4[2],
"node 2 must be in the tree"
);
assert_eq!(
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 3)).unwrap(),
VALUES4[3],
"node 3 must be in the tree"
);
assert_eq!(
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 4)).unwrap(),
RpoDigest::default(),
"unmodified node 4 must be ZERO"
);
// STORE LEAVES MATCH TREE ===============================================================
// sanity check the values returned by the store and the tree
assert_eq!(
smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 0)).unwrap(),
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 0)).unwrap(),
"node 0 must be the same for both SparseMerkleTree and MerkleStore"
);
assert_eq!(
smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 1)).unwrap(),
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 1)).unwrap(),
"node 1 must be the same for both SparseMerkleTree and MerkleStore"
);
assert_eq!(
smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 2)).unwrap(),
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 2)).unwrap(),
"node 2 must be the same for both SparseMerkleTree and MerkleStore"
);
assert_eq!(
smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 3)).unwrap(),
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 3)).unwrap(),
"node 3 must be the same for both SparseMerkleTree and MerkleStore"
);
assert_eq!(
smt.get_node(NodeIndex::make(SMT_MAX_DEPTH, 4)).unwrap(),
store.get_node(smt.root(), NodeIndex::make(SMT_MAX_DEPTH, 4)).unwrap(),
"node 4 must be the same for both SparseMerkleTree and MerkleStore"
);
@@ -386,46 +386,46 @@ fn test_add_merkle_paths() -> Result<(), MerkleError> {
// STORE LEAVES ARE CORRECT ==============================================================
// checks the leaves in the store correspond to the expected values
assert_eq!(
store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 0)).unwrap(),
VALUES4[0],
"node 0 must be in the pmt"
);
assert_eq!(
store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 1)).unwrap(),
VALUES4[1],
"node 1 must be in the pmt"
);
assert_eq!(
store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 2)).unwrap(),
VALUES4[2],
"node 2 must be in the pmt"
);
assert_eq!(
store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 3)).unwrap(),
VALUES4[3],
"node 3 must be in the pmt"
);
// STORE LEAVES MATCH PMT ================================================================
// sanity check the values returned by the store and the pmt
assert_eq!(
pmt.get_node(NodeIndex::make(pmt.max_depth(), 0)).unwrap(),
store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 0)).unwrap(),
"node 0 must be the same for both PartialMerkleTree and MerkleStore"
);
assert_eq!(
pmt.get_node(NodeIndex::make(pmt.max_depth(), 1)).unwrap(),
store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 1)).unwrap(),
"node 1 must be the same for both PartialMerkleTree and MerkleStore"
);
assert_eq!(
pmt.get_node(NodeIndex::make(pmt.max_depth(), 2)).unwrap(),
store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 2)).unwrap(),
"node 2 must be the same for both PartialMerkleTree and MerkleStore"
);
assert_eq!(
pmt.get_node(NodeIndex::make(pmt.max_depth(), 3)).unwrap(),
store.get_node(pmt.root(), NodeIndex::make(pmt.max_depth(), 3)).unwrap(),
"node 3 must be the same for both PartialMerkleTree and MerkleStore"
);
@@ -437,8 +437,8 @@ fn test_add_merkle_paths() -> Result<(), MerkleError> {
"Value for merkle path at index 0 must match leaf value" "Value for merkle path at index 0 must match leaf value"
); );
assert_eq!( assert_eq!(
pmt.get_path(NodeIndex::make(pmt.max_depth(), 0)), pmt.get_path(NodeIndex::make(pmt.max_depth(), 0)).unwrap(),
Ok(result.path), result.path,
"merkle path for index 0 must be the same for the MerkleTree and MerkleStore" "merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
); );
@@ -448,8 +448,8 @@ fn test_add_merkle_paths() -> Result<(), MerkleError> {
"Value for merkle path at index 0 must match leaf value" "Value for merkle path at index 0 must match leaf value"
); );
assert_eq!( assert_eq!(
pmt.get_path(NodeIndex::make(pmt.max_depth(), 1)), pmt.get_path(NodeIndex::make(pmt.max_depth(), 1)).unwrap(),
Ok(result.path), result.path,
"merkle path for index 1 must be the same for the MerkleTree and MerkleStore" "merkle path for index 1 must be the same for the MerkleTree and MerkleStore"
); );
@@ -459,8 +459,8 @@ fn test_add_merkle_paths() -> Result<(), MerkleError> {
"Value for merkle path at index 0 must match leaf value" "Value for merkle path at index 0 must match leaf value"
); );
assert_eq!( assert_eq!(
pmt.get_path(NodeIndex::make(pmt.max_depth(), 2)), pmt.get_path(NodeIndex::make(pmt.max_depth(), 2)).unwrap(),
Ok(result.path), result.path,
"merkle path for index 0 must be the same for the MerkleTree and MerkleStore" "merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
); );
@@ -470,8 +470,8 @@ fn test_add_merkle_paths() -> Result<(), MerkleError> {
"Value for merkle path at index 0 must match leaf value" "Value for merkle path at index 0 must match leaf value"
); );
assert_eq!( assert_eq!(
pmt.get_path(NodeIndex::make(pmt.max_depth(), 3)), pmt.get_path(NodeIndex::make(pmt.max_depth(), 3)).unwrap(),
Ok(result.path), result.path,
"merkle path for index 0 must be the same for the MerkleTree and MerkleStore" "merkle path for index 0 must be the same for the MerkleTree and MerkleStore"
); );
@@ -499,7 +499,7 @@ fn wont_open_to_different_depth_root() {
let store = MerkleStore::from(&mtree);
let index = NodeIndex::root();
let err = store.get_node(root, index).err().unwrap();
assert_matches!(err, MerkleError::RootNotInStore(err_root) if err_root == root);
}
#[test]
@@ -538,7 +538,7 @@ fn test_set_node() -> Result<(), MerkleError> {
let value = int_to_node(42);
let index = NodeIndex::make(mtree.depth(), 0);
let new_root = store.set_node(mtree.root(), index, value)?.root;
assert_eq!(store.get_node(new_root, index).unwrap(), value, "value must have changed");
Ok(())
}
@@ -614,7 +614,7 @@ fn node_path_should_be_truncated_by_midtier_insert() {
let path = store.get_path(root, index).unwrap().path;
assert_eq!(node, result);
assert_eq!(path.depth(), depth);
assert!(path.verify(index.value(), result, &root).is_ok());
// flip the first bit of the key and insert the second node on a different depth
let key = key ^ (1 << 63);
@@ -627,7 +627,7 @@ fn node_path_should_be_truncated_by_midtier_insert() {
let path = store.get_path(root, index).unwrap().path;
assert_eq!(node, result);
assert_eq!(path.depth(), depth);
assert!(path.verify(index.value(), result, &root).is_ok());
// attempt to fetch a path of the second node to depth 64
// should fail because the previously inserted node will remove its sub-tree from the set
@@ -725,7 +725,7 @@ fn get_leaf_depth_works_with_depth_8() {
assert_eq!(8, store.get_leaf_depth(root, 8, k).unwrap());
}
// flip last bit of a and expect it to return the same depth, but for an empty node
assert_eq!(8, store.get_leaf_depth(root, 8, 0b01101000_u64).unwrap());
// flip fourth bit of a and expect an empty node on depth 4
@@ -746,7 +746,7 @@ fn get_leaf_depth_works_with_depth_8() {
// duplicate the tree on `a` and assert the depth is short-circuited by such sub-tree
let index = NodeIndex::new(8, a).unwrap();
root = store.set_node(root, index, root).unwrap().root;
assert_matches!(store.get_leaf_depth(root, 8, a).unwrap_err(), MerkleError::DepthTooBig(9));
}
#[test]


@@ -7,7 +7,9 @@ pub use winter_utils::Randomizable;
use crate::{Felt, FieldElement, Word, ZERO};
mod rpo;
mod rpx;
pub use rpo::RpoRandomCoin;
pub use rpx::RpxRandomCoin;
/// Pseudo-random element generator.
///


@@ -1,10 +1,12 @@
use alloc::{string::ToString, vec::Vec};
use rand_core::impls;
use super::{Felt, FeltRng, FieldElement, RandomCoin, RandomCoinError, RngCore, Word, ZERO};
use crate::{
hash::rpo::{Rpo256, RpoDigest},
utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
};
// CONSTANTS
// ================================================================================================
@@ -20,8 +22,8 @@ const HALF_RATE_WIDTH: usize = (Rpo256::RATE_RANGE.end - Rpo256::RATE_RANGE.star
/// described in <https://eprint.iacr.org/2011/499.pdf>.
///
/// The simplification is related to the following facts:
/// 1. A call to the reseed method implies one and only one call to the permutation function. This
/// is possible because in our case we never reseed with more than 4 field elements.
/// 2. As a result of the previous point, we don't make use of an input buffer to accumulate seed
/// material.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -143,8 +145,10 @@ impl RandomCoin for RpoRandomCoin {
self.state[RATE_START] += nonce;
Rpo256::apply_permutation(&mut self.state);
// reset the buffer and move the next random element pointer to the second rate element.
// this is done as the first rate element will be "biased" via the provided `nonce` to
// contain some number of leading zeros.
self.current = RATE_START + 1;
// determine how many bits are needed to represent valid values in the domain
let v_mask = (domain_size - 1) as u64;
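The behavioral change here: after absorbing the grinding nonce, sampling starts at the second rate element, since the first was deliberately biased by the proof-of-work search. A toy illustration of the masking step that follows (the values are hypothetical, not from the crate):

```
fn main() {
    // domain_size must be a power of two, so `domain_size - 1` is an
    // all-ones mask selecting values in [0, domain_size).
    let domain_size: usize = 1 << 10;
    let v_mask = (domain_size - 1) as u64;
    let raw: u64 = 0xDEAD_BEEF; // a hypothetical PRNG output
    let position = (raw & v_mask) as usize;
    assert!(position < domain_size);
}
```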

src/rand/rpx.rs (new file, 294 lines)

@@ -0,0 +1,294 @@
use alloc::{string::ToString, vec::Vec};
use rand_core::impls;
use super::{Felt, FeltRng, FieldElement, RandomCoin, RandomCoinError, RngCore, Word, ZERO};
use crate::{
hash::rpx::{Rpx256, RpxDigest},
utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
};
// CONSTANTS
// ================================================================================================
const STATE_WIDTH: usize = Rpx256::STATE_WIDTH;
const RATE_START: usize = Rpx256::RATE_RANGE.start;
const RATE_END: usize = Rpx256::RATE_RANGE.end;
const HALF_RATE_WIDTH: usize = (Rpx256::RATE_RANGE.end - Rpx256::RATE_RANGE.start) / 2;
// RPX RANDOM COIN
// ================================================================================================
/// A simplified version of the `SPONGE_PRG` reseedable pseudo-random number generator algorithm
/// described in <https://eprint.iacr.org/2011/499.pdf>.
///
/// The simplification is related to the following facts:
/// 1. A call to the reseed method implies one and only one call to the permutation function. This
/// is possible because in our case we never reseed with more than 4 field elements.
/// 2. As a result of the previous point, we don't make use of an input buffer to accumulate seed
/// material.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RpxRandomCoin {
state: [Felt; STATE_WIDTH],
current: usize,
}
impl RpxRandomCoin {
/// Returns a new [RpxRandomCoin] initialized with the specified seed.
pub fn new(seed: Word) -> Self {
let mut state = [ZERO; STATE_WIDTH];
for i in 0..HALF_RATE_WIDTH {
state[RATE_START + i] += seed[i];
}
// Absorb
Rpx256::apply_permutation(&mut state);
RpxRandomCoin { state, current: RATE_START }
}
/// Returns an [RpxRandomCoin] instantiated from the provided components.
///
/// # Panics
/// Panics if `current` is smaller than 4 or greater than or equal to 12.
pub fn from_parts(state: [Felt; STATE_WIDTH], current: usize) -> Self {
assert!(
(RATE_START..RATE_END).contains(&current),
"current value outside of valid range"
);
Self { state, current }
}
/// Returns components of this random coin.
pub fn into_parts(self) -> ([Felt; STATE_WIDTH], usize) {
(self.state, self.current)
}
/// Fills `dest` with random data.
pub fn fill_bytes(&mut self, dest: &mut [u8]) {
<Self as RngCore>::fill_bytes(self, dest)
}
fn draw_basefield(&mut self) -> Felt {
if self.current == RATE_END {
Rpx256::apply_permutation(&mut self.state);
self.current = RATE_START;
}
self.current += 1;
self.state[self.current - 1]
}
}
// RANDOM COIN IMPLEMENTATION
// ------------------------------------------------------------------------------------------------
impl RandomCoin for RpxRandomCoin {
type BaseField = Felt;
type Hasher = Rpx256;
fn new(seed: &[Self::BaseField]) -> Self {
let digest: Word = Rpx256::hash_elements(seed).into();
Self::new(digest)
}
fn reseed(&mut self, data: RpxDigest) {
// Reset buffer
self.current = RATE_START;
// Add the new seed material to the first half of the rate portion of the RPX state
let data: Word = data.into();
self.state[RATE_START] += data[0];
self.state[RATE_START + 1] += data[1];
self.state[RATE_START + 2] += data[2];
self.state[RATE_START + 3] += data[3];
// Absorb
Rpx256::apply_permutation(&mut self.state);
}
fn check_leading_zeros(&self, value: u64) -> u32 {
let value = Felt::new(value);
let mut state_tmp = self.state;
state_tmp[RATE_START] += value;
Rpx256::apply_permutation(&mut state_tmp);
let first_rate_element = state_tmp[RATE_START].as_int();
first_rate_element.trailing_zeros()
}
fn draw<E: FieldElement<BaseField = Felt>>(&mut self) -> Result<E, RandomCoinError> {
let ext_degree = E::EXTENSION_DEGREE;
let mut result = vec![ZERO; ext_degree];
for r in result.iter_mut().take(ext_degree) {
*r = self.draw_basefield();
}
let result = E::slice_from_base_elements(&result);
Ok(result[0])
}
fn draw_integers(
&mut self,
num_values: usize,
domain_size: usize,
nonce: u64,
) -> Result<Vec<usize>, RandomCoinError> {
assert!(domain_size.is_power_of_two(), "domain size must be a power of two");
assert!(num_values < domain_size, "number of values must be smaller than domain size");
// absorb the nonce
let nonce = Felt::new(nonce);
self.state[RATE_START] += nonce;
Rpx256::apply_permutation(&mut self.state);
// reset the buffer
self.current = RATE_START;
// determine how many bits are needed to represent valid values in the domain
let v_mask = (domain_size - 1) as u64;
// draw values from PRNG until we get as many unique values as specified by num_queries
let mut values = Vec::new();
for _ in 0..1000 {
// get the next pseudo-random field element
let value = self.draw_basefield().as_int();
// use the mask to get a value within the range
let value = (value & v_mask) as usize;
values.push(value);
if values.len() == num_values {
break;
}
}
if values.len() < num_values {
return Err(RandomCoinError::FailedToDrawIntegers(num_values, values.len(), 1000));
}
Ok(values)
}
}
// FELT RNG IMPLEMENTATION
// ------------------------------------------------------------------------------------------------
impl FeltRng for RpxRandomCoin {
fn draw_element(&mut self) -> Felt {
self.draw_basefield()
}
fn draw_word(&mut self) -> Word {
let mut output = [ZERO; 4];
for o in output.iter_mut() {
*o = self.draw_basefield();
}
output
}
}
// RNGCORE IMPLEMENTATION
// ------------------------------------------------------------------------------------------------
impl RngCore for RpxRandomCoin {
fn next_u32(&mut self) -> u32 {
self.draw_basefield().as_int() as u32
}
fn next_u64(&mut self) -> u64 {
impls::next_u64_via_u32(self)
}
fn fill_bytes(&mut self, dest: &mut [u8]) {
impls::fill_bytes_via_next(self, dest)
}
fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
self.fill_bytes(dest);
Ok(())
}
}
// SERIALIZATION
// ------------------------------------------------------------------------------------------------
impl Serializable for RpxRandomCoin {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
self.state.iter().for_each(|v| v.write_into(target));
// casting to u8 is OK because `current` is always between 4 and 12.
target.write_u8(self.current as u8);
}
}
impl Deserializable for RpxRandomCoin {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let state = [
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
Felt::read_from(source)?,
];
let current = source.read_u8()? as usize;
if !(RATE_START..RATE_END).contains(&current) {
return Err(DeserializationError::InvalidValue(
"current value outside of valid range".to_string(),
));
}
Ok(Self { state, current })
}
}
// TESTS
// ================================================================================================

#[cfg(test)]
mod tests {
    use super::{Deserializable, FeltRng, RpxRandomCoin, Serializable, ZERO};
    use crate::ONE;

    #[test]
    fn test_feltrng_felt() {
        let mut rpxcoin = RpxRandomCoin::new([ZERO; 4]);
        let output = rpxcoin.draw_element();

        let mut rpxcoin = RpxRandomCoin::new([ZERO; 4]);
        let expected = rpxcoin.draw_basefield();

        assert_eq!(output, expected);
    }

    #[test]
    fn test_feltrng_word() {
        let mut rpxcoin = RpxRandomCoin::new([ZERO; 4]);
        let output = rpxcoin.draw_word();

        let mut rpxcoin = RpxRandomCoin::new([ZERO; 4]);
        let mut expected = [ZERO; 4];
        for o in expected.iter_mut() {
            *o = rpxcoin.draw_basefield();
        }

        assert_eq!(output, expected);
    }

    #[test]
    fn test_feltrng_serialization() {
        let coin1 = RpxRandomCoin::from_parts([ONE; 12], 5);

        let bytes = coin1.to_bytes();
        let coin2 = RpxRandomCoin::read_from_bytes(&bytes).unwrap();
        assert_eq!(coin1, coin2);
    }
}


@@ -62,8 +62,8 @@ impl<K: Ord + Clone, V: Clone> KvMap<K, V> for BTreeMap<K, V> {
 /// The [RecordingMap] is composed of three parts:
 /// - `data`: which contains the current set of key-value pairs in the map.
 /// - `updates`: which tracks keys for which values have been changed since the map was
 ///   instantiated. updates include both insertions, removals and updates of values under existing
 ///   keys.
 /// - `trace`: which contains the key-value pairs from the original data which have been accesses
 ///   since the map was instantiated.
 #[derive(Debug, Default, Clone, Eq, PartialEq)]
@@ -126,11 +126,10 @@ impl<K: Ord + Clone, V: Clone> KvMap<K, V> for RecordingMap<K, V> {
     ///
     /// If the key is part of the initial data set, the key access is recorded.
     fn get(&self, key: &K) -> Option<&V> {
-        self.data.get(key).map(|value| {
+        self.data.get(key).inspect(|&value| {
             if !self.updates.contains(key) {
                 self.trace.borrow_mut().insert(key.clone(), value.clone());
             }
-            value
         })
     }
@@ -155,11 +154,10 @@ impl<K: Ord + Clone, V: Clone> KvMap<K, V> for RecordingMap<K, V> {
     /// returned.
     fn insert(&mut self, key: K, value: V) -> Option<V> {
         let new_update = self.updates.insert(key.clone());
-        self.data.insert(key.clone(), value).map(|old_value| {
+        self.data.insert(key.clone(), value).inspect(|old_value| {
             if new_update {
                 self.trace.borrow_mut().insert(key, old_value.clone());
             }
-            old_value
         })
     }
@@ -167,12 +165,11 @@ impl<K: Ord + Clone, V: Clone> KvMap<K, V> for RecordingMap<K, V> {
     ///
     /// If the key exists in the data set, the old value is returned.
     fn remove(&mut self, key: &K) -> Option<V> {
-        self.data.remove(key).map(|old_value| {
+        self.data.remove(key).inspect(|old_value| {
             let new_update = self.updates.insert(key.clone());
             if new_update {
                 self.trace.borrow_mut().insert(key.clone(), old_value.clone());
             }
-            old_value
         })
     }
@@ -328,7 +325,8 @@ mod tests {
         let mut map = RecordingMap::new(ITEMS.to_vec());
         assert!(map.iter().all(|(x, y)| ITEMS.contains(&(*x, *y))));
-        // when inserting entry with key that already exists the iterator should return the new value
+        // when inserting entry with key that already exists the iterator should return the new
+        // value
         let new_value = 5;
         map.insert(4, new_value);
         assert_eq!(map.iter().count(), ITEMS.len());
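
The `map` → `inspect` swap above removes closures whose only job was to return their argument after a side effect: `Option::inspect` borrows the value, runs the side effect, and passes the option through unchanged. A standalone illustration (not taken from the source):

fn main() {
    let v = Some(42);

    // before: the closure must hand the value back explicitly
    let a = v.map(|x| {
        println!("read {x}");
        x
    });

    // after: `inspect` runs the side effect and yields the value unchanged
    let b = v.inspect(|x| println!("read {x}"));

    assert_eq!(a, b);
}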


@@ -1,7 +1,9 @@
 //! Utilities used in this crate which can also be generally useful downstream.

 use alloc::string::String;
-use core::fmt::{self, Display, Write};
+use core::fmt::{self, Write};
+
+use thiserror::Error;

 use super::Word;
@@ -46,36 +48,20 @@ pub fn bytes_to_hex_string<const N: usize>(data: [u8; N]) -> String {
 }

 /// Defines errors which can occur during parsing of hexadecimal strings.
-#[derive(Debug)]
+#[derive(Debug, Error)]
 pub enum HexParseError {
+    #[error(
+        "expected hex data to have length {expected}, including the 0x prefix, found {actual}"
+    )]
     InvalidLength { expected: usize, actual: usize },
+    #[error("hex encoded data must start with 0x prefix")]
     MissingPrefix,
+    #[error("hex encoded data must contain only characters [a-zA-Z0-9]")]
     InvalidChar,
+    #[error("hex encoded values of a Digest must be inside the field modulus")]
     OutOfRange,
 }

-impl Display for HexParseError {
-    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
-        match self {
-            HexParseError::InvalidLength { expected, actual } => {
-                write!(f, "Hex encoded RpoDigest must have length 66, including the 0x prefix. expected {expected} got {actual}")
-            }
-            HexParseError::MissingPrefix => {
-                write!(f, "Hex encoded RpoDigest must start with 0x prefix")
-            }
-            HexParseError::InvalidChar => {
-                write!(f, "Hex encoded RpoDigest must contain characters [a-zA-Z0-9]")
-            }
-            HexParseError::OutOfRange => {
-                write!(f, "Hex encoded values of an RpoDigest must be inside the field modulus")
-            }
-        }
-    }
-}
-
-#[cfg(feature = "std")]
-impl std::error::Error for HexParseError {}
-
 /// Parses a hex string into an array of bytes of known size.
 pub fn hex_to_bytes<const N: usize>(value: &str) -> Result<[u8; N], HexParseError> {
     let expected: usize = (N * 2) + 2;
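
With the derive in place, `Display` is generated from the `#[error(...)]` attributes and the `Error` trait impl comes from the macro, which is what makes the hand-written `Display` and `std::error::Error` impls above removable. A sketch of the resulting behavior, assuming `HexParseError` is in scope and using the message strings shown in the diff:

fn main() {
    // `to_string` goes through the derived `Display` impl
    let err = HexParseError::InvalidLength { expected: 66, actual: 64 };
    assert_eq!(
        err.to_string(),
        "expected hex data to have length 66, including the 0x prefix, found 64"
    );
    assert_eq!(
        HexParseError::MissingPrefix.to_string(),
        "hex encoded data must start with 0x prefix"
    );
}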