148 Commits

Author SHA1 Message Date
Al-Kindi-0
d2a6739605 feat: derandomize RPO-STARK DSA (#358) 2025-01-08 11:42:23 -08:00
Al-Kindi-0
cae87a2790 chore: add signatures benchmarks (#354) 2024-12-12 19:58:33 -08:00
Al-Kindi-0
335c50f54d feat: implement RPO STARK-based signature DSA (with zero knowledge) (#349) 2024-12-12 19:33:24 -08:00
Qyriad
b151773b0d feat: implement concurrent Smt construction (#341)
* merkle: add parent() helper function on NodeIndex
* smt: add pairs_to_leaf() to trait
* smt: add sorted_pairs_to_leaves() and test for it
* smt: implement single subtree-8 hashing, w/ benchmarks & tests

This will be composed into depth-8-subtree-based computation of entire
sparse Merkle trees.

* merkle: add a benchmark for constructing 256-balanced trees

This is intended for comparison with the benchmarks from the previous
commit. This benchmark represents the theoretical perfect-efficiency
performance we could possibly (but impractically) get for computing
depth-8 sparse Merkle subtrees.

* smt: test that SparseMerkleTree::build_subtree() is composable

* smt: test that subtree logic can correctly construct an entire tree

This commit ensures that `SparseMerkleTree::build_subtree()` can
correctly compose into building an entire sparse Merkle tree, without
yet getting into potential complications concurrency introduces.

* smt: implement test for basic parallelized subtree computation w/ rayon

Building on the previous commit, this commit implements a test proving
that `SparseMerkleTree::build_subtree()` can be composed into itself not
just concurrently, but in parallel, without issue.

* smt: add from_raw_parts() to trait interface

This commit adds a new required method to the SparseMerkleTree trait,
to allow generic construction from pre-computed parts.

This will be used to add a generic version of `with_entries()` in a
later commit.

* smt: add parallel constructors to Smt and SimpleSmt

What the previous few commits have been leading up to: SparseMerkleTree
now has a function to construct the tree from existing data in parallel.
This is significantly faster than the single-threaded equivalent.
Benchmarks incoming!

---------

Co-authored-by: krushimir <krushimir@reilabs.co>
Co-authored-by: krushimir <kresimir.grofelnik@reilabs.io>
2024-12-04 10:54:41 -08:00
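The subtree-composition idea described in this commit can be sketched in miniature: hash fixed-size subtrees on separate threads, then combine the subtree roots into the final root. The toy `merge` hash and the function names below are illustrative stand-ins, not the crate's API (which uses RPO and rayon); the point is the composability property the commit message claims, i.e. that the parallel path and the sequential path agree.

```rust
use std::thread;

// Toy 2-to-1 "hash" standing in for RPO; illustration only.
fn merge(a: u64, b: u64) -> u64 {
    a.wrapping_mul(31).wrapping_add(b).rotate_left(7) ^ 0x9e37_79b9
}

// Hash a power-of-two-length level of nodes up to a single root.
fn subtree_root(mut level: Vec<u64>) -> u64 {
    while level.len() > 1 {
        level = level.chunks(2).map(|p| merge(p[0], p[1])).collect();
    }
    level[0]
}

// Sequential reference: hash all leaves in one pass.
fn root_sequential(leaves: &[u64]) -> u64 {
    subtree_root(leaves.to_vec())
}

// Parallel: hash each subtree on its own thread, then combine the
// subtree roots. Assumes power-of-two leaf and subtree counts.
fn root_parallel(leaves: &[u64], subtree_size: usize) -> u64 {
    let roots: Vec<u64> = thread::scope(|s| {
        let handles: Vec<_> = leaves
            .chunks(subtree_size)
            .map(|chunk| s.spawn(move || subtree_root(chunk.to_vec())))
            .collect();
        handles.into_iter().map(|h| h.join().unwrap()).collect()
    });
    subtree_root(roots)
}
```

With a depth-8 subtree size (256 leaves per subtree, 8 here for brevity), both paths produce the same root, which is exactly what the `build_subtree()` composability tests above are checking.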
Bobbin Threadbare
1867f842d3 chore: update changelog 2024-11-24 22:26:51 -08:00
Al-Kindi-0
e1072ecc7f chore: update to winterfell dependencies to 0.11 (#346) 2024-11-24 22:20:19 -08:00
Bobbin Threadbare
063ad49afd chore: update crate version to v0.13.0 2024-11-21 15:56:55 -08:00
Philipp Gackstatter
a27f9ad828 refactor: use thiserror to derive errors and update error messages (#344) 2024-11-21 15:52:20 -08:00
Al-Kindi-0
50dd6bda19 fix: skip using the field element containing the proof-of-work (#343) 2024-11-18 00:16:27 -08:00
Bobbin Threadbare
ee20a49953 chore: increment crate version to v0.12.0 and update changelog 2024-10-30 15:04:08 -07:00
Al-Kindi-0
0d75e3593b chore: migrate to Winterfell v0.10.0 release (#338) 2024-10-29 15:02:46 -07:00
Bobbin Threadbare
689cc93ed1 chore: update crate version to v0.11.0 and set MSRV to 1.82 2024-10-17 23:16:41 -07:00
Bobbin Threadbare
7970d3a736 Merge branch 'main' into next 2024-10-17 20:53:09 -07:00
Al-Kindi-0
a734dace1e feat: update RPO's padding rule to use that in the xHash paper (#318) 2024-10-17 20:49:44 -07:00
Andrey Khmuro
940cc04670 feat: add Smt::is_empty (#337) 2024-10-17 14:27:50 -07:00
Andrey Khmuro
e82baa35bb feat: return error instead of panic during MMR verification (#335) 2024-10-17 07:23:29 -07:00
Bobbin Threadbare
876d1bf97a chore: update crate version v0.10.3 2024-09-26 09:37:34 -07:00
Philipp Gackstatter
8adc0ab418 feat: implement get_size_hint for Smt (#331) 2024-09-26 09:13:50 -07:00
Bobbin Threadbare
c2eb38c236 chore: increment crate version to v0.10.2 2024-09-25 03:05:33 -07:00
Philipp Gackstatter
a924ac6b81 feat: Add size hint for digests (#330) 2024-09-25 03:03:31 -07:00
Bobbin Threadbare
e214608c85 fix: bug introduced due to merging 2024-09-13 11:10:34 -07:00
Bobbin Threadbare
c44ccd9dec Merge branch 'main' into next 2024-09-13 11:01:04 -07:00
Bobbin Threadbare
e34900c7d8 chore: update version to v0.10.1 2024-09-13 10:58:06 -07:00
Santiago Pittella
2b184cd4ca feat: add de/serialization to InOrderIndex and PartialMmr (#329) 2024-09-13 08:47:46 -07:00
Bobbin Threadbare
913384600d chore: fix typos 2024-09-11 16:52:21 -07:00
Qyriad
ae807a47ae feat: implement transactional Smt insertion (#327)
* feat(smt): impl constructing leaves that don't yet exist

This commit implements 'prospective leaf construction' -- computing
sparse Merkle tree leaves for a key-value insertion without actually
performing that insertion.

For SimpleSmt, this is trivial, since the leaf type is simply the value
being inserted.

For the full Smt, the new leaf payload depends on the existing payload
in that leaf. Since almost all leaves are very small, we can just clone
the leaf and modify a copy.

This will allow us to perform more general prospective changes on Merkle
trees.

* feat(smt): export get_value() in the trait

* feat(smt): implement generic prospective insertions

This commit adds two methods to SparseMerkleTree: compute_mutations()
and apply_mutations(), which respectively create and consume a new
MutationSet type. This type represents a set of changes to a
SparseMerkleTree that haven't happened yet, and can be queried to
ensure a set of insertions results in the correct tree root before
finalizing and committing the mutation.

This is a step towards issue 222, and will enable removing Merkle
tree clones in miden-node InnerState::apply_block().

As part of this change, SparseMerkleTree now requires its Key to be Ord
and its Leaf to be Clone (both bounds which were already met by existing
implementations). The Ord bound could instead be changed to Eq + Hash,
if MutationSet were changed to use a HashMap instead of a BTreeMap.

* chore(smt): refactor empty node construction to helper function
2024-09-11 16:49:57 -07:00
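The compute/apply flow this commit describes can be sketched in miniature. The toy tree below replaces Merkle hashing with a trivial fold over its entries so the staged-mutation pattern stands alone; `ToyTree`, the `MutationSet` fields, and the root function are illustrative stand-ins, not the crate's actual types.

```rust
use std::collections::BTreeMap;

// Miniature stand-in for a sparse Merkle tree: the "root" is just a
// deterministic fold over all key-value pairs.
#[derive(Clone)]
struct ToyTree {
    entries: BTreeMap<u64, u64>,
}

// A set of changes that have not happened yet, plus the root the
// tree would have after applying them.
struct MutationSet {
    new_pairs: BTreeMap<u64, u64>,
    new_root: u64,
}

impl ToyTree {
    fn root_of(entries: &BTreeMap<u64, u64>) -> u64 {
        entries
            .iter()
            .fold(0u64, |acc, (k, v)| acc.wrapping_mul(31) ^ k.wrapping_add(*v))
    }

    fn root(&self) -> u64 {
        Self::root_of(&self.entries)
    }

    // Prospective insertion: compute what the root would become,
    // without modifying the tree itself.
    fn compute_mutations(&self, pairs: impl IntoIterator<Item = (u64, u64)>) -> MutationSet {
        let new_pairs: BTreeMap<u64, u64> = pairs.into_iter().collect();
        let mut projected = self.entries.clone();
        projected.extend(new_pairs.iter().map(|(k, v)| (*k, *v)));
        MutationSet { new_pairs, new_root: Self::root_of(&projected) }
    }

    // Commit a previously computed mutation set.
    fn apply_mutations(&mut self, set: MutationSet) {
        self.entries.extend(set.new_pairs);
    }
}
```

The caller can inspect `new_root` (e.g. compare it against an expected block root) before deciding to commit, which is the clone-avoiding pattern the commit message targets.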
Paul-Henry Kajfasz
f4a9d5b027 Merge pull request #323 from 0xPolygonMiden/phklive-consistent-ci
Update `Makefile` and `CI`
2024-08-22 08:22:20 -07:00
Paul-Henry Kajfasz
ee42d87121 Replace i. by 1. 2024-08-22 16:14:19 +01:00
Paul-Henry Kajfasz
b1cb2b6ec3 Fix comments 2024-08-22 15:21:59 +01:00
Paul-Henry Kajfasz
e4a9a2ac00 Updated test in workflow 2024-08-21 16:53:28 +01:00
Paul-Henry Kajfasz
c5077b1683 updated readme 2024-08-21 14:18:41 +01:00
Paul-Henry Kajfasz
2e74028fd4 Updated makefile 2024-08-21 14:11:17 +01:00
Paul-Henry Kajfasz
8bf6ef890d fmt 2024-08-21 14:04:23 +01:00
Paul-Henry Kajfasz
e2aeb25e01 Updated doc comments 2024-08-21 14:03:43 +01:00
Paul-Henry Kajfasz
790846cc73 Merge next 2024-08-21 09:29:39 +01:00
Paul-Henry Kajfasz
4cb6bed428 Updated changelog + added release to no-std 2024-08-19 14:37:58 +01:00
Bobbin Threadbare
a12e62ff22 feat: improve MMR api (#324) 2024-08-18 09:35:12 -07:00
Paul-Henry Kajfasz
9aa4987858 Merge branch 'phklive-consistent-ci' of github.com:0xPolygonMiden/crypto into phklive-consistent-ci 2024-08-16 17:29:29 -07:00
Paul-Henry Kajfasz
70a0a1e970 Removed Makefile.toml 2024-08-16 17:29:09 -07:00
Paul-Henry Kajfasz
025fbb66a9 Update README.md change miden-crypto to crypto 2024-08-17 01:21:19 +01:00
Paul-Henry Kajfasz
5ee5e8554b Ran pre-commit 2024-08-16 16:12:17 -07:00
Paul-Henry Kajfasz
ac3c6976bd Updated Changelog + pre commit 2024-08-16 16:09:51 -07:00
Paul-Henry Kajfasz
374a10f340 Updated ci + added scripts 2024-08-16 15:32:03 -07:00
Paul-Henry Kajfasz
ad0f472708 Updated Makefile and Readme 2024-08-16 15:07:27 -07:00
Bobbin Threadbare
8bb893345b chore: update rust version badge 2024-08-06 17:00:17 -07:00
Bobbin Threadbare
d92fae7f82 chore: update rust version badge 2024-08-06 16:59:31 -07:00
Bobbin Threadbare
b171575776 merge v0.10.0 release 2024-08-06 16:58:00 -07:00
Bobbin Threadbare
dfdd5f722f chore: fix lints 2024-08-06 16:52:46 -07:00
Bobbin Threadbare
9f63b50510 chore: increment crate version to v0.10.0 and update changelog 2024-08-06 16:42:50 -07:00
Elias Rad
d6ab367d32 chore: fix typos (#321) 2024-07-24 11:35:57 -07:00
Al-Kindi-0
b06cfa3c03 docs: update RPO with a comment on security given domain separation (#320) 2024-06-04 22:54:51 -07:00
Al-Kindi-0
8556c8fc43 fix: encoding Falcon secret key basis polynomials (#319) 2024-05-28 23:20:28 -07:00
Augusto Hack
78ac70120d fix: hex_to_bytes can be used for data besides RpoDigests (#317) 2024-05-13 13:13:02 -07:00
Bobbin Threadbare
ccde10af13 chore: update changelog 2024-05-12 03:17:06 +08:00
Al-Kindi-0
f967211b5a feat: migrate to new Winterfell (#315) 2024-05-12 03:09:27 +08:00
Augusto Hack
d58c717956 rpo/rpx: export digest error enum (#313) 2024-05-12 03:09:24 +08:00
Augusto Hack
c0743adac9 Rpo256: Add RpoDigest conversions (#311) 2024-05-12 03:09:21 +08:00
Bobbin Threadbare
f72add58cd chore: increment crate version to v0.9.3 and update changelog 2024-04-24 01:02:47 -07:00
Menko
63f97e5621 feat: add rpx random coin (#307) 2024-04-24 01:02:47 -07:00
Bobbin Threadbare
43fe7a1072 chore: increment crate version to 0.9.2 and update changelog 2024-04-21 01:14:18 -07:00
Al-Kindi-0
bb42388827 fix: bug in Falcon secret key basis order (#305) 2024-04-21 01:14:18 -07:00
Dominik Schmid
2a0ae70645 feature: adding serialization to the SMT (#304) 2024-04-21 01:14:18 -07:00
Bobbin Threadbare
da67f8c7e5 chore: increment doc version to v0.9.1 2024-04-02 13:07:02 -07:00
Bobbin Threadbare
9454e1a8ae chore: increment crate version to v0.9.1 2024-04-02 13:02:38 -07:00
Bobbin Threadbare
4bf087daf8 fix: decrement leaf count in simple SMT when inserting empty value (#303) 2024-04-02 13:01:00 -07:00
polydez
b4dc373925 feat: add leaf count to SimpleSmt (#302) 2024-04-02 12:07:00 -07:00
Bobbin Threadbare
4885f885a4 chore: update changelog 2024-03-24 08:42:38 -07:00
Bobbin Threadbare
5a2e917dd5 Tracking PR for v0.9.0 release (#278)
* chore: update crate version to v0.9.0
* chore: remove deprecated re-exports
* chore: remove Box re-export
* feat: implement pure-Rust keygen and signing for RpoFalcon512 (#285)
* feat: add reproducible builds (#296)
* fix: address a few issues for migrating Miden VM  (#298)
* feat: add RngCore supertrait for FeltRng (#299)

---------

Co-authored-by: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com>
Co-authored-by: Paul-Henry Kajfasz <42912740+phklive@users.noreply.github.com>
2024-03-24 08:38:08 -07:00
Bobbin Threadbare
2be17b74fb fix: add re-exports of vec and format macros 2024-03-18 12:35:00 -07:00
Bobbin Threadbare
b35e99c390 chore: increment crate version to v0.8.3 and update changelog 2024-03-18 11:29:06 -07:00
Paul Schoenfelder
4c8a9809ed fix: re-add unintentionally removed re-exported liballoc macros (#292) 2024-03-18 11:27:17 -07:00
Bobbin Threadbare
ce9b45fe77 chore: add badges to readme 2024-03-17 13:32:46 -07:00
Bobbin Threadbare
56d014898d chore: update copyright year 2024-03-17 13:25:26 -07:00
Bobbin Threadbare
8e81ccdb68 chore: increment version to v0.8.2 and update changelog 2024-03-17 13:23:44 -07:00
Paul Schoenfelder
999a64fca6 chore: handle deprecations in winterfell 0.8.3 release 2024-03-17 16:18:23 -04:00
Bobbin Threadbare
4bc4bea0db chore: update changelog 2024-02-21 23:59:36 -05:00
Augusto Hack
dbab0e9aa9 fix: clippy warnings (#280) 2024-02-21 20:55:02 -08:00
Bobbin Threadbare
24f72c986b chore: update changelog 2024-02-14 11:52:40 -08:00
Andrey Khmuro
cd4525c7ad refactor: update repo to be compatible with Winterfell 0.8 (#275) 2024-02-14 11:52:40 -08:00
Philippe Laferrière
552d90429b Remove TieredSmt (#277) 2024-02-14 11:52:40 -08:00
Philippe Laferrière
119c7e2b6d SmtProof: add accessors (#276)
* add accessors

* fmt

* comments
2024-02-14 11:52:40 -08:00
Philippe Laferrière
45e7e78118 Clone (#274) 2024-02-14 11:52:40 -08:00
Philippe Laferrière
a9475b2a2d reexport (#273) 2024-02-14 11:52:40 -08:00
Philippe Laferrière
e55b3ed2ce Introduce SmtProof (#270)
* add conversion for `SmtLeaf`

* introduce `SmtProof` scaffolding

* implement `verify_membership()`

* SmtLeaf: knows its index

* `SmtLeaf::index`

* `SmtLeaf::get_value()` returns an Option

* fix `verify_membership()`

* impl `SmtProof::get`

* impl `into_parts()`

* `SmtProof::compute_root`

* use `SmtProof` in `Smt::open`

* `SmtLeaf` constructors

* Vec

* impl `Error` for `SmtLeafError`

* fix std Error

* move Word/Digest conversions to LeafIndex

* `SmtProof::new()` returns an error

* `SparseMerkleTree::path_and_leaf_to_opening`

* `SmtLeaf`: serializable/deserializable

* `SmtProof`: serializable/deserializable

* add tests for SmtLeaf serialization

* move `SmtLeaf` to submodule

* use constructors internally

* fix docs

* Add `Vec`

* add `Vec` to tests

* no_std use statements

* fmt

* `Errors`: make heading

* use `SMT_DEPTH`

* SmtLeaf single case: check leaf index

* Multiple case: check consistency with leaf index

* use `pub(super)` instead of `pub(crate)`

* use `pub(super)`

* `SmtLeaf`: add `num_entries()` accessor

* Fix `SmtLeaf` serialization

* improve leaf serialization tests
2024-02-14 11:52:40 -08:00
Bobbin Threadbare
61a0764a61 fix: peak index calculation in MmrProof 2024-02-14 11:52:40 -08:00
Philippe Laferrière
3d71a9b59b Smt: remove inner nodes when removing value (#269) 2024-02-14 11:52:40 -08:00
Philippe Laferrière
da12fd258a Add missing methods to Smt (#268) 2024-02-14 11:52:40 -08:00
Bobbin Threadbare
5fcf98669d feat: add PartialMmr::from_parts() constructor 2024-02-14 11:52:40 -08:00
Philippe Laferrière
1cdd3dbbfa Add methods to Smt necessary for VM tests (#264)
* Smt::inner_nodes

* Add conversion Smt -> MerkleStore

* add docstring to `Smt`

* add to docstring

* fmt

* add `leaves()` method to `Smt`

* add `kv_pairs` functions

* rewrite `into_elements()` in terms of `into_kv_pairs()`

* change docstring
2024-02-14 11:52:40 -08:00
Bobbin Threadbare
d59ffe274a feat: add Debug and Clone derives for Falcon signature 2024-02-14 11:52:40 -08:00
Bobbin Threadbare
727ed8fb3e docs: minor padding comment update 2024-02-14 11:52:40 -08:00
Al-Kindi-0
0acceaa526 fix: always pad bytes with 10*0 (#267) 2024-02-14 11:52:40 -08:00
Michael Birch
3882e0f719 fix(dsa): fix deserialization logic (#266) 2024-02-14 11:52:40 -08:00
Augusto F. Hack
70e39e7b39 partialmmr: Method add with support for a single peak and tracking
fixes: #258
2024-02-14 11:52:40 -08:00
Philippe Laferrière
5596db7868 Implement Smt struct (replacement to TieredSmt) (#254) 2024-02-14 11:52:40 -08:00
Bobbin Threadbare
a933ff2fa0 refactor: remove obsolete traits 2024-02-14 11:52:40 -08:00
Philippe Laferrière
8ea37904e3 Introduce SparseMerkleTree trait (#245) 2024-02-14 11:52:40 -08:00
Augusto F. Hack
1004246bfe ci: verify docs syntax 2024-02-14 11:52:40 -08:00
Augusto F. Hack
dae9de9068 docs: fix warnings 2024-02-14 11:52:40 -08:00
Bobbin Threadbare
7e9d4a4316 feat: add to_hex() to RpoDigest and RpxDigest 2024-02-14 11:52:40 -08:00
Al-Kindi-0
c9ab3beccc New padding rule for RPX (#236)
* feat: new padding rule for RPX
* fix: documentation on security
2024-02-14 11:52:40 -08:00
cristiantroy
260592f8e7 Fix: typos (#249)
* tests: fix typos
* full.rs: fix typo
* CONTRIBUTING: fix typo
2024-02-14 11:52:40 -08:00
Bobbin Threadbare
6b5db8a6db fix: clippy 2024-02-14 11:52:40 -08:00
Bobbin Threadbare
3ebee98b0f feat: add PartialMmr::is_tracked() 2024-02-14 11:52:40 -08:00
Bobbin Threadbare
457c985a92 refactor: remove sve feature flag 2024-02-14 11:52:40 -08:00
Bobbin Threadbare
f894ed9cde chore: update CI.yaml 2024-02-14 11:52:40 -08:00
Bobbin Threadbare
ac7593a13c chore: update CI jobs 2024-02-14 11:52:40 -08:00
Bobbin Threadbare
004a3bc7a8 docs: update changelog and readme 2024-02-14 11:52:40 -08:00
Grzegorz Swirski
479fe5e649 feat: use AVX2 instructions whenever available 2024-02-14 11:52:40 -08:00
Bobbin Threadbare
a0f533241f fix: bugfix in PartialMmr apply delta 2024-02-14 11:52:40 -08:00
Al-Kindi-0
05309b19bb chore: export default Winterfell randomcoin 2024-02-14 11:52:40 -08:00
Bobbin Threadbare
be1d631630 feat: add Clone derive to PartialMmr 2024-02-14 11:52:40 -08:00
Bobbin Threadbare
4d0d8d3058 refactor: return MmrPeaks from PartialMmr::peaks() 2024-02-14 11:52:40 -08:00
Bobbin Threadbare
59d93cb8ba fix: typos 2024-02-14 11:52:40 -08:00
Bobbin Threadbare
9baddfd138 feat: implement inner_nodes() iterator for PartialMmr 2024-02-14 11:52:40 -08:00
Bobbin Threadbare
8f92f44a55 feat: add serialization to RpoRandomCoin 2024-02-14 11:52:40 -08:00
Al-Kindi-0
36d3b8dc46 feat: move RpoRandomCoin and define Rng trait
nits: minor

chore: update log and readme
2024-02-14 11:52:40 -08:00
Augusto F. Hack
7e13346e04 serde: for MerklePath, ValuePath, and RootPath 2024-02-14 11:52:40 -08:00
Philippe Laferrière
9a18ed6749 Implement SimpleSmt::set_subtree (#232)
* recompute_nodes_from_index_to_root

* MerkleError variant

* set_subtree

* test_simplesmt_set_subtree

* test_simplesmt_set_subtree_entire_tree

* test

* set_subtree: return root
2024-02-14 11:52:40 -08:00
Augusto F. Hack
df2650eb1f bugfix: TSMT failed to verify empty word for depth 64.
When a prefix is pushed to depth 64, the entry list includes only
values different from ZERO. This is required, since each block
represents 2^192 values.

The bug was in the membership proof code, which failed to handle the
case of a key that was not in the list, because the depth is 64 and
the value was not set.
2024-02-14 11:52:40 -08:00
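The rule this fix enforces can be stated as a tiny lookup: a depth-64 leaf prefix covers 2^192 of the 2^256 possible keys, so the leaf stores only keys with non-zero values, and a key absent from the list must resolve to the empty word rather than fail verification. A minimal sketch with simplified types (not the crate's actual TieredSmt code):

```rust
// A "word" here is four u64 field elements; all-zero is the empty word.
const EMPTY_WORD: [u64; 4] = [0; 4];

/// Value associated with `key` in a depth-64 leaf: either the stored
/// value, or the empty word if the key's value was never set.
fn leaf_value(entries: &[([u64; 4], [u64; 4])], key: &[u64; 4]) -> [u64; 4] {
    entries
        .iter()
        .find(|(k, _)| k == key)
        .map(|(_, v)| *v)
        .unwrap_or(EMPTY_WORD)
}
```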
Philippe Laferrière
18310a89f0 MmrPeaks::hash_peaks() returns Digest (#230) 2024-02-14 11:52:40 -08:00
Philippe Laferrière
d719cc2663 Remove ExactSizeIterator constraint from SimpleSmt::with_leaves() (#228)
* Change InvalidNumEntries error

* max computation

* remove length check

* remove ExactSizeIterator constraint

* fix InvalidNumEntries error condition

* 2_usize
2024-02-14 11:52:40 -08:00
Augusto F. Hack
fa475d1929 simplesmt: simplify duplicate check 2024-02-14 11:52:40 -08:00
Philippe Laferrière
25b8cb64ba Introduce SimpleSmt::with_contiguous_leaves() (#227)
* with_contiguous_leaves

* test
2024-02-14 11:52:40 -08:00
Augusto F. Hack
389fcb03c2 simplesmt: bugfix, index must be validated before modifying the tree 2024-02-14 11:52:40 -08:00
Austin Abell
b7cb346e22 feat: memoize Signature polynomial decoding 2024-02-14 11:52:40 -08:00
Philippe Laferriere
fd480f827a Consuming iterator for RpoDigest 2024-02-14 11:52:40 -08:00
Augusto F. Hack
9f95582654 mmr: add into_parts for the peaks 2024-02-14 11:52:40 -08:00
Augusto F. Hack
1f92d5417a simple_smt: reduce serialized size, use static hashes of the empty word 2024-02-14 11:52:40 -08:00
Augusto F. Hack
9b0ce0810b mmr: support accumulator of older forest versions 2024-02-14 11:52:40 -08:00
Bobbin Threadbare
938250453a chore: update changelog 2024-02-14 11:52:40 -08:00
Al-Kindi-0
9ccac2baf0 chore: bump winterfell release to .7 2024-02-14 11:52:40 -08:00
Bobbin Threadbare
525062d023 docs: update bench readme 2024-02-14 11:52:40 -08:00
Augusto F. Hack
3a5264c428 mmr: support proofs with older forest versions 2024-02-14 11:52:40 -08:00
Augusto F. Hack
a8acc0b39d mmr: support arbitrary from/to delta updates 2024-02-14 11:52:40 -08:00
Augusto F. Hack
5f2d170435 mmr: publicly export MmrDelta 2024-02-14 11:52:40 -08:00
Bobbin Threadbare
9d52958f64 docs: update changelog 2024-02-14 11:52:40 -08:00
Al-Kindi-0
a2a26e2aba docs: added RPX benchmarks 2024-02-14 11:52:40 -08:00
Al-Kindi-0
3125144445 feat: RPX (xHash12) hash function implementation 2024-02-14 11:52:40 -08:00
Augusto F. Hack
f33a982f29 rpo: added conversions for digest 2024-02-14 11:52:40 -08:00
Bobbin Threadbare
41f03fbe91 chore: update main readme 2024-02-14 11:52:40 -08:00
Augusto F. Hack
65495aeb18 config: add .editorconfig 2024-02-14 11:52:40 -08:00
Bobbin Threadbare
0a2d440524 chore: update crate version to v0.8 2024-02-14 11:52:40 -08:00
Bobbin Threadbare
c86bdc6d51 Merge pull request #226 from shuoer86/main
chore: fix typos
2023-11-26 15:18:25 -08:00
shuoer86
650508cbc9 chore: fix typos 2023-11-26 21:19:03 +08:00
Augusto Hack
012ad5ae93 Merge pull request #195 from 0xPolygonMiden/hacka-partial-mmr2
mmr: added partial mmr
2023-10-19 20:30:24 +02:00
Augusto F. Hack
bde20f9752 mmr: added partial mmr 2023-10-19 20:15:49 +02:00
Bobbin Threadbare
7f3d4b8966 fix: RPO Falcon build on Windows 2023-10-10 15:16:51 -07:00
121 changed files with 18116 additions and 7965 deletions

3
.config/nextest.toml Normal file

@@ -0,0 +1,3 @@
[profile.default]
failure-output = "immediate-final"
fail-fast = false

20
.editorconfig Normal file

@@ -0,0 +1,20 @@
# Documentation available at editorconfig.org
root=true
[*]
indent_style = space
indent_size = 4
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
[*.rs]
max_line_length = 100
[*.md]
trim_trailing_whitespace = false
[*.yml]
indent_size = 2

25
.github/workflows/build.yml vendored Normal file

@@ -0,0 +1,25 @@
# Runs build related jobs.
name: build
on:
push:
branches: [main, next]
pull_request:
types: [opened, reopened, synchronize]
jobs:
no-std:
name: Build for no-std
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
toolchain: [stable, nightly]
steps:
- uses: actions/checkout@main
- name: Build for no-std
run: |
rustup update --no-self-update ${{ matrix.toolchain }}
rustup target add wasm32-unknown-unknown
make build-no-std

23
.github/workflows/changelog.yml vendored Normal file

@@ -0,0 +1,23 @@
# Runs changelog related jobs.
# CI job heavily inspired by: https://github.com/tarides/changelog-check-action
name: changelog
on:
pull_request:
types: [opened, reopened, synchronize, labeled, unlabeled]
jobs:
changelog:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@main
with:
fetch-depth: 0
- name: Check for changes in changelog
env:
BASE_REF: ${{ github.event.pull_request.base.ref }}
NO_CHANGELOG_LABEL: ${{ contains(github.event.pull_request.labels.*.name, 'no changelog') }}
run: ./scripts/check-changelog.sh "${{ inputs.changelog }}"
shell: bash


@@ -1,101 +0,0 @@
name: CI
on:
push:
branches:
- main
pull_request:
types: [opened, reopened, synchronize]
jobs:
build:
name: Build ${{matrix.toolchain}} on ${{matrix.os}} with ${{matrix.args}}
runs-on: ${{matrix.os}}-latest
strategy:
fail-fast: false
matrix:
toolchain: [stable, nightly]
os: [ubuntu]
target: [wasm32-unknown-unknown]
args: [--no-default-features --target wasm32-unknown-unknown]
steps:
- uses: actions/checkout@main
with:
submodules: recursive
- name: Install rust
uses: actions-rs/toolchain@v1
with:
toolchain: ${{matrix.toolchain}}
override: true
- run: rustup target add ${{matrix.target}}
- name: Test
uses: actions-rs/cargo@v1
with:
command: build
args: ${{matrix.args}}
test:
name: Test ${{matrix.toolchain}} on ${{matrix.os}} with ${{matrix.features}}
runs-on: ${{matrix.os}}-latest
strategy:
fail-fast: false
matrix:
toolchain: [stable, nightly]
os: [ubuntu]
features: ["--features default,std,serde", --no-default-features]
steps:
- uses: actions/checkout@main
with:
submodules: recursive
- name: Install rust
uses: actions-rs/toolchain@v1
with:
toolchain: ${{matrix.toolchain}}
override: true
- name: Test
uses: actions-rs/cargo@v1
with:
command: test
args: ${{matrix.features}}
clippy:
name: Clippy with ${{matrix.features}}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
features: ["--features default,std,serde", --no-default-features]
steps:
- uses: actions/checkout@main
with:
submodules: recursive
- name: Install minimal nightly with clippy
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: nightly
components: clippy
override: true
- name: Clippy
uses: actions-rs/cargo@v1
with:
command: clippy
args: --all ${{matrix.features}} -- -D clippy::all -D warnings
rustfmt:
name: rustfmt
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@main
- name: Install minimal stable with rustfmt
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
components: rustfmt
override: true
- name: rustfmt
uses: actions-rs/cargo@v1
with:
command: fmt
args: --all -- --check

53
.github/workflows/lint.yml vendored Normal file

@@ -0,0 +1,53 @@
# Runs linting related jobs.
name: lint
on:
push:
branches: [main, next]
pull_request:
types: [opened, reopened, synchronize]
jobs:
clippy:
name: clippy nightly on ubuntu-latest
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@main
- name: Clippy
run: |
rustup update --no-self-update nightly
rustup +nightly component add clippy
make clippy
rustfmt:
name: rustfmt check nightly on ubuntu-latest
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@main
- name: Rustfmt
run: |
rustup update --no-self-update nightly
rustup +nightly component add rustfmt
make format-check
doc:
name: doc stable on ubuntu-latest
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@main
- name: Build docs
run: |
rustup update --no-self-update stable
make doc
version:
name: check rust version consistency
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@main
- name: check rust versions
run: ./scripts/check-rust-version.sh

28
.github/workflows/test.yml vendored Normal file

@@ -0,0 +1,28 @@
# Runs test related jobs.
name: test
on:
push:
branches: [main, next]
pull_request:
types: [opened, reopened, synchronize]
jobs:
test:
name: test ${{matrix.toolchain}} on ${{matrix.os}} with ${{matrix.args}}
runs-on: ${{matrix.os}}-latest
strategy:
fail-fast: false
matrix:
toolchain: [stable, nightly]
os: [ubuntu]
args: [default, no-std]
timeout-minutes: 30
steps:
- uses: actions/checkout@main
- uses: taiki-e/install-action@nextest
- name: Perform tests
run: |
rustup update --no-self-update ${{matrix.toolchain}}
make test-${{matrix.args}}

7
.gitignore vendored

@@ -2,12 +2,11 @@
# will have compiled files and executables
/target/
# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
Cargo.lock
# These are backup files generated by rustfmt
**/*.rs.bk
# Generated by cmake
cmake-build-*
# VS Code
.vscode/

3
.gitmodules vendored

@@ -1,3 +0,0 @@
[submodule "PQClean"]
path = PQClean
url = https://github.com/PQClean/PQClean.git


@@ -1,43 +1,34 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-json
- id: check-toml
- id: pretty-format-json
- id: check-added-large-files
- id: check-case-conflict
- id: check-executables-have-shebangs
- id: check-merge-conflict
- id: detect-private-key
- repo: https://github.com/hackaugusto/pre-commit-cargo
rev: v1.0.0
hooks:
# Allows cargo fmt to modify the source code prior to the commit
- id: cargo
name: Cargo fmt
args: ["+stable", "fmt", "--all"]
stages: [commit]
# Requires code to be properly formatted prior to pushing upstream
- id: cargo
name: Cargo fmt --check
args: ["+stable", "fmt", "--all", "--check"]
stages: [push, manual]
- id: cargo
name: Cargo check --all-targets
args: ["+stable", "check", "--all-targets"]
- id: cargo
name: Cargo check --all-targets --no-default-features
args: ["+stable", "check", "--all-targets", "--no-default-features"]
- id: cargo
name: Cargo check --all-targets --features default,std,serde
args: ["+stable", "check", "--all-targets", "--features", "default,std,serde"]
# Unlike fmt, clippy will not be automatically applied
- id: cargo
name: Cargo clippy
args: ["+nightly", "clippy", "--workspace", "--", "--deny", "clippy::all", "--deny", "warnings"]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-json
- id: check-toml
- id: pretty-format-json
- id: check-added-large-files
- id: check-case-conflict
- id: check-executables-have-shebangs
- id: check-merge-conflict
- id: detect-private-key
- repo: local
hooks:
- id: lint
name: Make lint
stages: [commit]
language: rust
entry: make lint
- id: doc
name: Make doc
stages: [commit]
language: rust
entry: make doc
- id: check
name: Make check
stages: [commit]
language: rust
entry: make check


@@ -1,27 +1,120 @@
## 0.13.0 (2024-11-24)
- Fixed a bug in the implementation of `draw_integers` for `RpoRandomCoin` (#343).
- [BREAKING] Refactor error messages and use `thiserror` to derive errors (#344).
- [BREAKING] Updated Winterfell dependency to v0.11 (#346).
- Added RPO-STARK based DSA (#349).
- Added benchmarks for DSA implementations (#354).
- Implemented deterministic RPO-STARK based DSA (#358).
## 0.12.0 (2024-10-30)
- [BREAKING] Updated Winterfell dependency to v0.10 (#338).
- Added parallel implementation of `Smt::with_entries()` with significantly better performance when the `concurrent` feature is enabled (#341).
## 0.11.0 (2024-10-17)
- [BREAKING] Renamed `Mmr::open()` to `Mmr::open_at()` and `Mmr::peaks()` to `Mmr::peaks_at()` (#234).
- Added `Mmr::open()` and `Mmr::peaks()`, which rely on `Mmr::open_at()` and `Mmr::peaks_at()` respectively (#234).
- Standardized CI and Makefile across Miden repos (#323).
- Added `Smt::compute_mutations()` and `Smt::apply_mutations()` for validation-checked insertions (#327).
- Changed padding rule for RPO/RPX hash functions (#318).
- [BREAKING] Changed return value of `Mmr::verify()` and `MerklePath::verify()` from `bool` to `Result<>` (#335).
- Added `is_empty()` functions to the `SimpleSmt` and `Smt` structures. Added `EMPTY_ROOT` constant to the `SparseMerkleTree` trait (#337).
## 0.10.3 (2024-09-25)
- Implement `get_size_hint` for `Smt` (#331).
## 0.10.2 (2024-09-25)
- Implement `get_size_hint` for `RpoDigest` and `RpxDigest` and expose constants for their serialized size (#330).
## 0.10.1 (2024-09-13)
- Added `Serializable` and `Deserializable` implementations for `PartialMmr` and `InOrderIndex` (#329).
## 0.10.0 (2024-08-06)
- Added more `RpoDigest` and `RpxDigest` conversions (#311).
- [BREAKING] Migrated to Winterfell v0.9 (#315).
- Fixed encoding of Falcon secret key (#319).
## 0.9.3 (2024-04-24)
- Added `RpxRandomCoin` struct (#307).
## 0.9.2 (2024-04-21)
- Implemented serialization for the `Smt` struct (#304).
- Fixed a bug in Falcon signature generation (#305).
## 0.9.1 (2024-04-02)
- Added `num_leaves()` method to `SimpleSmt` (#302).
## 0.9.0 (2024-03-24)
- [BREAKING] Removed deprecated re-exports from liballoc/libstd (#290).
- [BREAKING] Refactored RpoFalcon512 signature to work with pure Rust (#285).
- [BREAKING] Added `RngCore` as supertrait for `FeltRng` (#299).
## 0.8.4 (2024-03-17)
- Re-added unintentionally removed re-exported liballoc macros (`vec` and `format` macros).
## 0.8.3 (2024-03-17)
- Re-added unintentionally removed re-exported liballoc macros (#292).
## 0.8.2 (2024-03-17)
- Updated `no-std` approach to be in sync with winterfell v0.8.3 release (#290).
## 0.8.1 (2024-02-21)
- Fixed clippy warnings (#280).
## 0.8.0 (2024-02-14)
- Implemented the `PartialMmr` data structure (#195).
- Implemented RPX hash function (#201).
- Added `FeltRng` and `RpoRandomCoin` (#237).
- Accelerated RPO/RPX hash functions using AVX512 instructions (#234).
- Added `inner_nodes()` method to `PartialMmr` (#238).
- Improved `PartialMmr::apply_delta()` (#242).
- Refactored `SimpleSmt` struct (#245).
- Replaced `TieredSmt` struct with `Smt` struct (#254, #277).
- Updated Winterfell dependency to v0.8 (#275).
## 0.7.1 (2023-10-10)
- Fixed RPO Falcon signature build on Windows.
## 0.7.0 (2023-10-05)
* Replaced `MerklePathSet` with `PartialMerkleTree` (#165).
* Implemented clearing of nodes in `TieredSmt` (#173).
* Added ability to generate inclusion proofs for `TieredSmt` (#174).
* Implemented Falcon DSA (#179).
* Added conditional `serde`` support for various structs (#180).
* Implemented benchmarking for `TieredSmt` (#182).
* Added more leaf traversal methods for `MerkleStore` (#185).
* Added SVE acceleration for RPO hash function (#189).
- Replaced `MerklePathSet` with `PartialMerkleTree` (#165).
- Implemented clearing of nodes in `TieredSmt` (#173).
- Added ability to generate inclusion proofs for `TieredSmt` (#174).
- Implemented Falcon DSA (#179).
- Added conditional `serde` support for various structs (#180).
- Implemented benchmarking for `TieredSmt` (#182).
- Added more leaf traversal methods for `MerkleStore` (#185).
- Added SVE acceleration for RPO hash function (#189).
## 0.6.0 (2023-06-25)
- [BREAKING] Added support for recording capabilities for `MerkleStore` (#162).
- [BREAKING] Refactored Merkle struct APIs to use `RpoDigest` instead of `Word` (#157).
- Added initial implementation of `PartialMerkleTree` (#156).
## 0.5.0 (2023-05-26)
- Implemented `TieredSmt` (#152, #153).
- Implemented ability to extract a subset of a `MerkleStore` (#151).
- Cleaned up `SimpleSmt` interface (#149).
- Decoupled hashing and padding of peaks in `Mmr` (#148).
- Added `inner_nodes()` to `MerkleStore` (#146).
## 0.4.0 (2023-04-21)
- Initial release on crates.io containing the cryptographic primitives used in Miden VM and the Miden Rollup.
- Hash module with the BLAKE3 and Rescue Prime Optimized hash functions.
- BLAKE3 is implemented with 256-bit, 192-bit, or 160-bit output.
- RPO is implemented with 256-bit output.
- Merkle module, with a set of data structures related to Merkle trees, implemented using the RPO hash function.


// ================================================================================
```
- [Rustfmt](https://github.com/rust-lang/rustfmt) and [Clippy](https://github.com/rust-lang/rust-clippy) linting is included in the CI pipeline. It is preferable to run linting locally before pushing:
```
cargo fix --allow-staged --allow-dirty --all-targets --all-features; cargo fmt; cargo clippy --workspace --all-targets --all-features -- -D warnings
```

Cargo.lock generated Normal file

[package]
name = "miden-crypto"
version = "0.14.0"
description = "Miden Cryptographic primitives"
authors = ["miden contributors"]
readme = "README.md"
license = "MIT"
repository = "https://github.com/0xPolygonMiden/crypto"
documentation = "https://docs.rs/miden-crypto/0.14.0"
categories = ["cryptography", "no-std"]
keywords = ["miden", "crypto", "hash", "merkle"]
edition = "2021"
rust-version = "1.82"
[[bin]]
name = "miden-crypto"
bench = false
doctest = false
required-features = ["executable"]
[[bench]]
name = "dsa"
harness = false
[[bench]]
name = "hash"
harness = false
[[bench]]
name = "smt"
harness = false
[[bench]]
name = "smt-subtree"
harness = false
required-features = ["internal"]
[[bench]]
name = "merkle"
harness = false
[[bench]]
name = "smt-with-entries"
harness = false
[[bench]]
name = "store"
harness = false
[features]
sve = ["std"]
concurrent = ["dep:rayon"]
default = ["std", "concurrent"]
executable = ["dep:clap", "dep:rand-utils", "std"]
internal = []
serde = ["dep:serde", "serde?/alloc", "winter-math/serde"]
std = [
"blake3/std",
"dep:cc",
"rand/std",
"rand/std_rng",
"winter-crypto/std",
"winter-math/std",
"winter-utils/std",
]
[dependencies]
blake3 = { version = "1.5", default-features = false }
clap = { version = "4.5", optional = true, features = ["derive"] }
getrandom = { version = "0.2", features = ["js"] }
num = { version = "0.4", default-features = false, features = ["alloc", "libm"] }
num-complex = { version = "0.4", default-features = false }
rand = { version = "0.8", default-features = false }
rand_chacha = { version = "0.3", default-features = false }
rand_core = { version = "0.6", default-features = false }
rand-utils = {git = 'https://github.com/Al-Kindi-0/winterfell', package = "winter-rand-utils" , branch = 'al-zk', optional = true }
rayon = { version = "1.10", optional = true }
serde = { version = "1.0", default-features = false, optional = true, features = ["derive"] }
sha3 = { version = "0.10", default-features = false }
thiserror = { version = "2.0", default-features = false }
winter-air = {git = 'https://github.com/Al-Kindi-0/winterfell', branch = 'al-zk' }
winter-crypto = {git = 'https://github.com/Al-Kindi-0/winterfell', branch = 'al-zk' }
winter-prover = {git = 'https://github.com/Al-Kindi-0/winterfell', branch = 'al-zk' }
winter-verifier = {git = 'https://github.com/Al-Kindi-0/winterfell', branch = 'al-zk' }
winter-math = {git = 'https://github.com/Al-Kindi-0/winterfell', branch = 'al-zk' }
winter-utils = {git = 'https://github.com/Al-Kindi-0/winterfell', branch = 'al-zk' }
[dev-dependencies]
assert_matches = { version = "1.5", default-features = false }
criterion = { version = "0.5", features = ["html_reports"] }
hex = { version = "0.4", default-features = false, features = ["alloc"] }
proptest = "1.5"
rand-utils = {git = 'https://github.com/Al-Kindi-0/winterfell', package = "winter-rand-utils" , branch = 'al-zk' }
seq-macro = { version = "0.3" }
[build-dependencies]
cc = { version = "1.2", optional = true, features = ["parallel"] }
glob = "0.3"

MIT License
Copyright (c) 2024 Polygon Miden
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

Makefile Normal file
.DEFAULT_GOAL := help
.PHONY: help
help:
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
# -- variables --------------------------------------------------------------------------------------
WARNINGS=RUSTDOCFLAGS="-D warnings"
DEBUG_OVERFLOW_INFO=RUSTFLAGS="-C debug-assertions -C overflow-checks -C debuginfo=2"
# -- linting --------------------------------------------------------------------------------------
.PHONY: clippy
clippy: ## Run Clippy with configs
$(WARNINGS) cargo +nightly clippy --workspace --all-targets --all-features
.PHONY: fix
fix: ## Run Fix with configs
cargo +nightly fix --allow-staged --allow-dirty --all-targets --all-features
.PHONY: format
format: ## Run Format using nightly toolchain
cargo +nightly fmt --all
.PHONY: format-check
format-check: ## Run Format using nightly toolchain but only in check mode
cargo +nightly fmt --all --check
.PHONY: lint
lint: format fix clippy ## Run all linting tasks at once (Clippy, fixing, formatting)
# --- docs ----------------------------------------------------------------------------------------
.PHONY: doc
doc: ## Generate and check documentation
$(WARNINGS) cargo doc --all-features --keep-going --release
# --- testing -------------------------------------------------------------------------------------
.PHONY: test-default
test-default: ## Run tests with default features
$(DEBUG_OVERFLOW_INFO) cargo nextest run --profile default --release --all-features
.PHONY: test-no-std
test-no-std: ## Run tests with `no-default-features` (no-std)
$(DEBUG_OVERFLOW_INFO) cargo nextest run --profile default --release --no-default-features
.PHONY: test
test: test-default test-no-std ## Run all tests
# --- checking ------------------------------------------------------------------------------------
.PHONY: check
check: ## Check all targets and features for errors without code generation
cargo check --all-targets --all-features
# --- building ------------------------------------------------------------------------------------
.PHONY: build
build: ## Build with default features enabled
cargo build --release
.PHONY: build-no-std
build-no-std: ## Build without the standard library
cargo build --release --no-default-features --target wasm32-unknown-unknown
.PHONY: build-avx2
build-avx2: ## Build with avx2 support
RUSTFLAGS="-C target-feature=+avx2" cargo build --release
.PHONY: build-sve
build-sve: ## Build with sve support
RUSTFLAGS="-C target-feature=+sve" cargo build --release
# --- benchmarking --------------------------------------------------------------------------------
.PHONY: bench-tx
bench-tx: ## Run crypto benchmarks
cargo bench --features="concurrent"

Submodule PQClean deleted from c3abebf4ab

# Miden Crypto
[![LICENSE](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/0xPolygonMiden/crypto/blob/main/LICENSE)
[![test](https://github.com/0xPolygonMiden/crypto/actions/workflows/test.yml/badge.svg)](https://github.com/0xPolygonMiden/crypto/actions/workflows/test.yml)
[![build](https://github.com/0xPolygonMiden/crypto/actions/workflows/build.yml/badge.svg)](https://github.com/0xPolygonMiden/crypto/actions/workflows/build.yml)
[![RUST_VERSION](https://img.shields.io/badge/rustc-1.82+-lightgray.svg)](https://www.rust-lang.org/tools/install)
[![CRATE](https://img.shields.io/crates/v/miden-crypto)](https://crates.io/crates/miden-crypto)
This crate contains cryptographic primitives used in Polygon Miden.
## Hash
[Hash module](./src/hash) provides a set of cryptographic hash functions which are used by the Miden VM and the Miden rollup. Currently, these functions are:
- [BLAKE3](https://github.com/BLAKE3-team/BLAKE3) hash function with 256-bit, 192-bit, or 160-bit output. The 192-bit and 160-bit outputs are obtained by truncating the 256-bit output of the standard BLAKE3.
- [RPO](https://eprint.iacr.org/2022/1577) hash function with 256-bit output. This hash function is an algebraic hash function suitable for recursive STARKs.
- [RPX](https://eprint.iacr.org/2023/1045) hash function with 256-bit output. Similar to RPO, this hash function is suitable for recursive STARKs but it is about 2x faster as compared to RPO.
For performance benchmarks of these hash functions and their comparison to other popular hash functions please see [here](./benches/).
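The calling shape of these hashers, as exercised by the benchmark code in this repository (`benches/hash.rs`), looks roughly like the sketch below; exact signatures may differ between versions:

```rust
use miden_crypto::{
    hash::rpo::{Rpo256, RpoDigest},
    Felt,
};

// Hash raw bytes into a 256-bit RPO digest.
let d1: RpoDigest = Rpo256::hash(&[1_u8]);
let d2: RpoDigest = Rpo256::hash(&[2_u8]);

// 2-to-1 merge of two digests, as used when building Merkle trees.
let parent = Rpo256::merge(&[d1, d2]);

// Sequential hashing of a slice of field elements into one digest.
let elements: Vec<Felt> = (0..100).map(Felt::new).collect();
let digest = Rpo256::hash_elements(&elements);
```

`Rpx256` exposes the same entry points (`hash`, `merge`, `hash_elements`) with `RpxDigest` as the digest type.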
## Merkle
[Merkle module](./src/merkle/) provides a set of data structures related to Merkle trees. All these data structures are implemented using the RPO hash function described above. The data structures are:
- `MerkleStore`: a collection of Merkle trees of different heights designed to efficiently store trees with common subtrees. When instantiated with `RecordingMap`, a Merkle store records all accesses to the original data.
- `MerkleTree`: a regular fully-balanced binary Merkle tree. The depth of this tree can be at most 64.
- `Mmr`: a Merkle mountain range structure designed to function as an append-only log.
- `PartialMerkleTree`: a partial view of a Merkle tree where some sub-trees may not be known. This is similar to a collection of Merkle paths all resolving to the same root. The length of the paths can be at most 64.
- `PartialMmr`: a partial view of a Merkle mountain range structure.
- `SimpleSmt`: a Sparse Merkle Tree (with no compaction), mapping 64-bit keys to 4-element values.
- `Smt`: a Sparse Merkle tree (with compaction at depth 64), mapping 4-element keys to 4-element values.
The module also contains additional supporting components such as `NodeIndex`, `MerklePath`, and `MerkleError` to assist with tree indexation, opening proofs, and reporting inconsistent arguments/state.
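To illustrate the level-by-level folding that all of these tree structures share, here is a self-contained toy sketch in plain Rust; it uses std's `DefaultHasher` as a stand-in for RPO and is not part of this crate's API:

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Toy 2-to-1 merge; a stand-in for RPO's digest merge.
fn merge(a: u64, b: u64) -> u64 {
    let mut h = DefaultHasher::new();
    a.hash(&mut h);
    b.hash(&mut h);
    h.finish()
}

// Root of a fully balanced binary Merkle tree: fold pairs of nodes
// level by level until a single root remains.
fn merkle_root(mut level: Vec<u64>) -> u64 {
    assert!(level.len().is_power_of_two());
    while level.len() > 1 {
        level = level.chunks(2).map(|pair| merge(pair[0], pair[1])).collect();
    }
    level[0]
}

fn main() {
    // 256 leaves => a tree of depth 8 (2^8 = 256).
    let leaves: Vec<u64> = (0..256).collect();
    let root = merkle_root(leaves);
    println!("root = {root:#x}");
}
```

A real `MerkleTree` replaces the toy `merge` with the RPO 2-to-1 compression and stores the intermediate nodes so that Merkle paths can be opened.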
## Signatures
[DSA module](./src/dsa) provides a set of digital signature schemes supported by default in the Miden VM. Currently, these schemes are:
- `RPO Falcon512`: a variant of the [Falcon](https://falcon-sign.info/) signature scheme. This variant differs from the standard in that instead of using SHAKE256 hash function in the _hash-to-point_ algorithm we use RPO256. This makes the signature more efficient to verify in Miden VM.
For the above signatures, key generation, signing, and signature verification are available for both `std` and `no_std` contexts (see [crate features](#crate-features) below). However, in `no_std` context, the user is responsible for supplying the key generation and signing procedures with a random number generator.
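As a rough sketch of the calling shape for RPO Falcon512 — taken from the benchmark code in this repository (`benches/dsa.rs`) rather than from any stability guarantee:

```rust
use miden_crypto::dsa::rpo_falcon512::SecretKey;

// Key generation draws randomness internally (available under `std`;
// in `no_std` the caller supplies the RNG).
let sk = SecretKey::new();
let pk = sk.public_key();

// `msg` stands for a 4-element `Word` to be signed.
let sig = sk.sign(msg);
let is_valid = pk.verify(msg, &sig);
```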
## Pseudo-Random Element Generator
[Pseudo random element generator module](./src/rand/) provides a set of traits and data structures that facilitate generating pseudo-random elements in the context of Miden VM and Miden rollup. The module currently includes:
- `FeltRng`: a trait for generating random field elements and random 4 field elements.
- `RpoRandomCoin`: a struct implementing `FeltRng` as well as the [`RandomCoin`](https://github.com/facebook/winterfell/blob/main/crypto/src/random/mod.rs) trait using RPO hash function.
- `RpxRandomCoin`: a struct implementing `FeltRng` as well as the [`RandomCoin`](https://github.com/facebook/winterfell/blob/main/crypto/src/random/mod.rs) trait using RPX hash function.
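A hedged sketch of how these pieces are meant to fit together (the construction and method names below are assumptions for illustration, not taken from this document):

```rust
use miden_crypto::rand::{FeltRng, RpoRandomCoin};

// Construct a coin from some seed material, then draw randomness
// through the `FeltRng` trait.
let mut coin = RpoRandomCoin::new(seed); // `seed` is an assumed parameter
let element = coin.draw_element();       // one random field element
let word = coin.draw_word();             // four random field elements
```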
## Make commands
We use `make` to automate building, testing, and other processes. In most cases, `make` commands are wrappers around `cargo` commands with specific arguments. You can view the list of available commands in the [Makefile](Makefile), or run the following command:
```shell
make
```
## Crate features
This crate can be compiled with the following features:
- `concurrent` - enabled by default; enables a multi-threaded implementation of `Smt::with_entries()` which significantly improves performance on multi-core CPUs.
- `std` - enabled by default and relies on the Rust standard library.
- `no_std` does not rely on the Rust standard library and enables compilation to WebAssembly.
All of these features imply the use of [alloc](https://doc.rust-lang.org/alloc/) to support heap-allocated collections.
To compile with `no_std`, disable default features via `--no-default-features` flag or using the following command:
```shell
make build-no-std
```
### AVX2 acceleration
On platforms with [AVX2](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions) support, RPO and RPX hash function can be accelerated by using the vector processing unit. To enable AVX2 acceleration, the code needs to be compiled with the `avx2` target feature enabled. For example:
```shell
make build-avx2
```
### SVE acceleration
On platforms with [SVE](<https://en.wikipedia.org/wiki/AArch64#Scalable_Vector_Extension_(SVE)>) support, RPO and RPX hash function can be accelerated by using the vector processing unit. To enable SVE acceleration, the code needs to be compiled with the `sve` target feature enabled. For example:
```shell
make build-sve
```
## Testing
The best way to test the library is using our [Makefile](Makefile); this will enable you to use our pre-defined, optimized testing commands:
```shell
make test
```
For example, some of the functions are heavy, so the tests might take a while to complete when simply using `cargo test`. To test in release mode, we have to replicate the test conditions of the development mode so all debug assertions can be verified.
We do that by enabling some special [flags](https://doc.rust-lang.org/cargo/reference/profiles.html) for the compilation (which we have set as a default in our [Makefile](Makefile)):
```shell
RUSTFLAGS="-C debug-assertions -C overflow-checks -C debuginfo=2" cargo test --release
```
## License
This project is [MIT licensed](./LICENSE).

# Benchmarks
## Miden VM Hash Functions
In the Miden VM, we make use of different hash functions. Some of these are "traditional" hash functions, like `BLAKE3`, which are optimized for out-of-STARK performance, while others are algebraic hash functions, like `Rescue Prime`, and are more optimized for a better performance inside the STARK. In what follows, we benchmark several such hash functions and compare against other constructions that are used by other proving systems. More precisely, we benchmark:
* **BLAKE3** as specified [here](https://github.com/BLAKE3-team/BLAKE3-specs/blob/master/blake3.pdf) and implemented [here](https://github.com/BLAKE3-team/BLAKE3) (with a wrapper exposed via this crate).
* **Poseidon** as specified [here](https://eprint.iacr.org/2019/458.pdf) and implemented [here](https://github.com/mir-protocol/plonky2/blob/806b88d7d6e69a30dc0b4775f7ba275c45e8b63b/plonky2/src/hash/poseidon_goldilocks.rs) (but in pure Rust, without vectorized instructions).
* **Rescue Prime (RP)** as specified [here](https://eprint.iacr.org/2020/1143) and implemented [here](https://github.com/novifinancial/winterfell/blob/46dce1adf0/crypto/src/hash/rescue/rp64_256/mod.rs).
* **Rescue Prime Optimized (RPO)** as specified [here](https://eprint.iacr.org/2022/1577) and implemented in this crate.
* **Rescue Prime Extended (RPX)** a variant of the [xHash](https://eprint.iacr.org/2023/1045) hash function as implemented in this crate.
### Comparison and Instructions
#### Comparison
We benchmark the above hash functions using two scenarios. The first is a 2-to-1 $(a,b)\mapsto h(a,b)$ hashing where both $a$, $b$ and $h(a,b)$ are the digests corresponding to each of the hash functions.
The second scenario is that of sequential hashing where we take a sequence of length $100$ field elements and hash these to produce a single digest. The digests are $4$ field elements in a prime field with modulus $2^{64} - 2^{32} + 1$ (i.e., 32 bytes) for Poseidon, Rescue Prime and RPO, and an array `[u8; 32]` for SHA3 and BLAKE3.
##### Scenario 1: 2-to-1 hashing `h(a,b)`
| Function | BLAKE3 | SHA3 | Poseidon | Rp64_256 | RPO_256 | RPX_256 |
| ------------------- | ------ | ------- | --------- | --------- | ------- | ------- |
| Apple M1 Pro | 76 ns | 245 ns | 1.5 µs | 9.1 µs | 5.2 µs | 2.7 µs |
| Apple M2 Max | 71 ns | 233 ns | 1.3 µs | 7.9 µs | 4.6 µs | 2.4 µs |
| Amazon Graviton 3 | 108 ns | | | | 5.3 µs | 3.1 µs |
| AMD Ryzen 9 5950X | 64 ns | 273 ns | 1.2 µs | 9.1 µs | 5.5 µs | |
| AMD EPYC 9R14 | 83 ns | | | | 4.3 µs | 2.4 µs |
| Intel Core i5-8279U | 68 ns | 536 ns | 2.0 µs | 13.6 µs | 8.5 µs | 4.4 µs |
| Intel Xeon 8375C | 67 ns | | | | 8.2 µs | |
##### Scenario 2: Sequential hashing of 100 elements `h([a_0,...,a_99])`
| Function | BLAKE3 | SHA3 | Poseidon | Rp64_256 | RPO_256 | RPX_256 |
| ------------------- | -------| ------- | --------- | --------- | ------- | ------- |
| Apple M1 Pro | 1.0 µs | 1.5 µs | 19.4 µs | 118 µs | 69 µs | 35 µs |
| Apple M2 Max | 0.9 µs | 1.5 µs | 17.4 µs | 103 µs | 60 µs | 31 µs |
| Amazon Graviton 3 | 1.4 µs | | | | 69 µs | 41 µs |
| AMD Ryzen 9 5950X | 0.8 µs | 1.7 µs | 15.7 µs | 120 µs | 72 µs | |
| AMD EPYC 9R14 | 0.9 µs | | | | 56 µs | 32 µs |
| Intel Core i5-8279U | 0.9 µs | | | | 107 µs | 56 µs |
| Intel Xeon 8375C | 0.8 µs | | | | 110 µs | |
Notes:
- On Graviton 3, RPO256 and RPX256 are run with SVE acceleration enabled.
- On AMD EPYC 9R14, RPO256 and RPX256 are run with AVX2 acceleration enabled.
#### Instructions
Before you can run the benchmarks, you'll need to make sure you have Rust [installed](https://www.rust-lang.org/tools/install). After that, to run the benchmarks for RPO and BLAKE3, clone the current repository, and from the root directory of the repo run the following:
```
cargo bench hash
```
To run the benchmarks for Rescue Prime, Poseidon and SHA3, clone the following repository, and from the root directory of the repo run the following:
```
cargo bench hash
```
## Miden VM DSA
We make use of the following digital signature algorithms (DSA) in the Miden VM:
* **RPO-Falcon512** as specified [here](https://falcon-sign.info/falcon.pdf) with the one difference being the use of the RPO hash function for the hash-to-point algorithm (Algorithm 3 in the previous reference) instead of SHAKE256.
* **RPO-STARK** as specified [here](https://eprint.iacr.org/2024/1553), where the parameters are the ones for the unique-decoding regime (UDR) with the two differences:
* We rely on Conjecture 1 in the [ethSTARK](https://eprint.iacr.org/2021/582) paper.
* The number of FRI queries is $30$ and the grinding factor is $12$ bits. Thus using the previous point we can argue that the modified version achieves at least $102$ bits of average-case existential unforgeability security against $2^{113}$-query bound adversaries that can obtain up to $2^{64}$ signatures under the same public key.
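The quoted security level can be reproduced with a back-of-the-envelope count: under the conjecture, each FRI query contributes about $\log_2(1/\rho)$ bits of security, where $\rho$ is the code rate. Assuming a blowup factor of $8$ (i.e., $\rho = 1/8$, an assumption not stated above), the estimate works out as

$$30 \cdot \log_2 8 + 12 = 90 + 12 = 102 \text{ bits.}$$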
### Comparison and Instructions
#### Comparison
##### Key Generation
| DSA | RPO-Falcon512 | RPO-STARK |
| ------------------- | :-----------: | :-------: |
| Apple M1 Pro | 590 ms | 6 µs |
| Intel Core i5-8279U | 585 ms | 10 µs |
##### Signature Generation
| DSA | RPO-Falcon512 | RPO-STARK |
| ------------------- | :-----------: | :-------: |
| Apple M1 Pro | 1.5 ms | 78 ms |
| Intel Core i5-8279U | 1.8 ms | 130 ms |
##### Signature Verification
| DSA | RPO-Falcon512 | RPO-STARK |
| ------------------- | :-----------: | :-------: |
| Apple M1 Pro | 0.7 ms | 4.5 ms |
| Intel Core i5-8279U | 1.2 ms | 7.9 ms |
#### Instructions
Before you can run the benchmarks, you'll need to make sure you have Rust [installed](https://www.rust-lang.org/tools/install). After that, to run the benchmarks, clone the current repository, and from the root directory of the repo run the following:
```
cargo bench --bench dsa
```

benches/dsa.rs Normal file
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
use miden_crypto::dsa::{
rpo_falcon512::SecretKey as FalconSecretKey, rpo_stark::SecretKey as RpoStarkSecretKey,
};
use rand_utils::rand_array;
fn key_gen_falcon(c: &mut Criterion) {
c.bench_function("Falcon public key generation", |bench| {
bench.iter_batched(|| FalconSecretKey::new(), |sk| sk.public_key(), BatchSize::SmallInput)
});
c.bench_function("Falcon secret key generation", |bench| {
bench.iter_batched(|| {}, |_| FalconSecretKey::new(), BatchSize::SmallInput)
});
}
fn key_gen_rpo_stark(c: &mut Criterion) {
c.bench_function("RPO-STARK public key generation", |bench| {
bench.iter_batched(
|| RpoStarkSecretKey::random(),
|sk| sk.public_key(),
BatchSize::SmallInput,
)
});
c.bench_function("RPO-STARK secret key generation", |bench| {
bench.iter_batched(|| {}, |_| RpoStarkSecretKey::random(), BatchSize::SmallInput)
});
}
fn signature_gen_falcon(c: &mut Criterion) {
c.bench_function("Falcon signature generation", |bench| {
bench.iter_batched(
|| (FalconSecretKey::new(), rand_array().into()),
|(sk, msg)| sk.sign(msg),
BatchSize::SmallInput,
)
});
}
fn signature_gen_rpo_stark(c: &mut Criterion) {
c.bench_function("RPO-STARK signature generation", |bench| {
bench.iter_batched(
|| (RpoStarkSecretKey::random(), rand_array().into()),
|(sk, msg)| sk.sign(msg),
BatchSize::SmallInput,
)
});
}
fn signature_ver_falcon(c: &mut Criterion) {
c.bench_function("Falcon signature verification", |bench| {
bench.iter_batched(
|| {
let sk = FalconSecretKey::new();
let msg = rand_array().into();
(sk.public_key(), msg, sk.sign(msg))
},
|(pk, msg, sig)| pk.verify(msg, &sig),
BatchSize::SmallInput,
)
});
}
fn signature_ver_rpo_stark(c: &mut Criterion) {
c.bench_function("RPO-STARK signature verification", |bench| {
bench.iter_batched(
|| {
let sk = RpoStarkSecretKey::random();
let msg = rand_array().into();
(sk.public_key(), msg, sk.sign(msg))
},
|(pk, msg, sig)| pk.verify(msg, &sig),
BatchSize::SmallInput,
)
});
}
criterion_group!(
dsa_group,
key_gen_falcon,
key_gen_rpo_stark,
signature_gen_falcon,
signature_gen_rpo_stark,
signature_ver_falcon,
signature_ver_rpo_stark
);
criterion_main!(dsa_group);


use miden_crypto::{
hash::{
blake::Blake3_256,
rpo::{Rpo256, RpoDigest},
rpx::{Rpx256, RpxDigest},
},
Felt,
};
fn rpo256_sequential(c: &mut Criterion) {
let v: [Felt; 100] = (0..100)
.map(Felt::new)
.collect::<Vec<Felt>>()
.try_into()
bench.iter_batched(
|| {
let v: [Felt; 100] = (0..100)
.map(|_| Felt::new(rand_value()))
.collect::<Vec<Felt>>()
.try_into()
});
}
fn rpx256_2to1(c: &mut Criterion) {
let v: [RpxDigest; 2] = [Rpx256::hash(&[1_u8]), Rpx256::hash(&[2_u8])];
c.bench_function("RPX256 2-to-1 hashing (cached)", |bench| {
bench.iter(|| Rpx256::merge(black_box(&v)))
});
c.bench_function("RPX256 2-to-1 hashing (random)", |bench| {
bench.iter_batched(
|| {
[
Rpx256::hash(&rand_value::<u64>().to_le_bytes()),
Rpx256::hash(&rand_value::<u64>().to_le_bytes()),
]
},
|state| Rpx256::merge(&state),
BatchSize::SmallInput,
)
});
}
fn rpx256_sequential(c: &mut Criterion) {
let v: [Felt; 100] = (0..100)
.map(Felt::new)
.collect::<Vec<Felt>>()
.try_into()
.expect("should not fail");
c.bench_function("RPX256 sequential hashing (cached)", |bench| {
bench.iter(|| Rpx256::hash_elements(black_box(&v)))
});
c.bench_function("RPX256 sequential hashing (random)", |bench| {
bench.iter_batched(
|| {
let v: [Felt; 100] = (0..100)
.map(|_| Felt::new(rand_value()))
.collect::<Vec<Felt>>()
.try_into()
.expect("should not fail");
v
},
|state| Rpx256::hash_elements(&state),
BatchSize::SmallInput,
)
});
}
fn blake3_2to1(c: &mut Criterion) {
let v: [<Blake3_256 as Hasher>::Digest; 2] =
[Blake3_256::hash(&[1_u8]), Blake3_256::hash(&[2_u8])];
fn blake3_sequential(c: &mut Criterion) {
let v: [Felt; 100] = (0..100)
.map(Felt::new)
.collect::<Vec<Felt>>()
.try_into()
bench.iter_batched(
|| {
let v: [Felt; 100] = (0..100)
.map(|_| Felt::new(rand_value()))
.collect::<Vec<Felt>>()
.try_into()
});
}
criterion_group!(
hash_group,
rpx256_2to1,
rpx256_sequential,
rpo256_2to1,
rpo256_sequential,
blake3_2to1,
blake3_sequential
);
criterion_main!(hash_group);

benches/merkle.rs Normal file

@@ -0,0 +1,66 @@
//! Benchmark for building a [`miden_crypto::merkle::MerkleTree`]. This is intended to be compared
//! with the results from `benches/smt-subtree.rs`, as building a fully balanced Merkle tree with
//! 256 leaves should indicate the *absolute best* performance we could *possibly* get for building
//! a depth-8 sparse Merkle subtree, though practically speaking building a fully balanced Merkle
//! tree will perform better than the sparse version. At the time of this writing (2024/11/24), this
//! benchmark is about four times more efficient than the equivalent benchmark in
//! `benches/smt-subtree.rs`.
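As a quick sanity check on the numbers in this doc comment, a fully balanced binary tree over 256 leaves has depth 8, which is exactly what the `assert_eq!(tree.depth(), 8)` calls below verify. A standalone sketch of that arithmetic (no `miden_crypto` types involved):

```rust
fn main() {
    // A fully balanced binary Merkle tree over 2^d leaves has depth d.
    let leaves: u32 = 256;
    let depth = leaves.ilog2(); // 256 = 2^8
    assert_eq!(depth, 8);

    // Total node count across all levels of such a tree: 2 * 256 - 1.
    let nodes = 2 * leaves - 1;
    assert_eq!(nodes, 511);

    println!("depth = {depth}, nodes = {nodes}");
}
```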
use std::{hint, mem, time::Duration};
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
use miden_crypto::{merkle::MerkleTree, Felt, Word, ONE};
use rand_utils::prng_array;
fn balanced_merkle_even(c: &mut Criterion) {
c.bench_function("balanced-merkle-even", |b| {
b.iter_batched(
|| {
let entries: Vec<Word> =
(0..256).map(|i| [Felt::new(i), ONE, ONE, Felt::new(i)]).collect();
assert_eq!(entries.len(), 256);
entries
},
|leaves| {
let tree = MerkleTree::new(hint::black_box(leaves)).unwrap();
assert_eq!(tree.depth(), 8);
},
BatchSize::SmallInput,
);
});
}
fn balanced_merkle_rand(c: &mut Criterion) {
let mut seed = [0u8; 32];
c.bench_function("balanced-merkle-rand", |b| {
b.iter_batched(
|| {
let entries: Vec<Word> = (0..256).map(|_| generate_word(&mut seed)).collect();
assert_eq!(entries.len(), 256);
entries
},
|leaves| {
let tree = MerkleTree::new(hint::black_box(leaves)).unwrap();
assert_eq!(tree.depth(), 8);
},
BatchSize::SmallInput,
);
});
}
criterion_group! {
name = smt_subtree_group;
config = Criterion::default()
.measurement_time(Duration::from_secs(20))
.configure_from_args();
targets = balanced_merkle_even, balanced_merkle_rand
}
criterion_main!(smt_subtree_group);
// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------
fn generate_word(seed: &mut [u8; 32]) -> Word {
mem::swap(seed, &mut prng_array(*seed));
let nums: [u64; 4] = prng_array(*seed);
[Felt::new(nums[0]), Felt::new(nums[1]), Felt::new(nums[2]), Felt::new(nums[3])]
}

benches/smt-subtree.rs Normal file

@@ -0,0 +1,142 @@
use std::{fmt::Debug, hint, mem, time::Duration};
use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use miden_crypto::{
hash::rpo::RpoDigest,
merkle::{build_subtree_for_bench, NodeIndex, SmtLeaf, SubtreeLeaf, SMT_DEPTH},
Felt, Word, ONE,
};
use rand_utils::prng_array;
use winter_utils::Randomizable;
const PAIR_COUNTS: [u64; 5] = [1, 64, 128, 192, 256];
fn smt_subtree_even(c: &mut Criterion) {
let mut seed = [0u8; 32];
let mut group = c.benchmark_group("subtree8-even");
for pair_count in PAIR_COUNTS {
let bench_id = BenchmarkId::from_parameter(pair_count);
group.bench_with_input(bench_id, &pair_count, |b, &pair_count| {
b.iter_batched(
|| {
// Setup.
let entries: Vec<(RpoDigest, Word)> = (0..pair_count)
.map(|n| {
// A single depth-8 subtree can have a maximum of 255 leaves.
let leaf_index = ((n as f64 / pair_count as f64) * 255.0) as u64;
let key = RpoDigest::new([
generate_value(&mut seed),
ONE,
Felt::new(n),
Felt::new(leaf_index),
]);
let value = generate_word(&mut seed);
(key, value)
})
.collect();
let mut leaves: Vec<_> = entries
.iter()
.map(|(key, value)| {
let leaf = SmtLeaf::new_single(*key, *value);
let col = NodeIndex::from(leaf.index()).value();
let hash = leaf.hash();
SubtreeLeaf { col, hash }
})
.collect();
leaves.sort();
leaves.dedup_by_key(|leaf| leaf.col);
leaves
},
|leaves| {
// Benchmarked function.
let (subtree, _) = build_subtree_for_bench(
hint::black_box(leaves),
hint::black_box(SMT_DEPTH),
hint::black_box(SMT_DEPTH),
);
assert!(!subtree.is_empty());
},
BatchSize::SmallInput,
);
});
}
}
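The setup closure above spreads `pair_count` keys across the columns of a depth-8 subtree via `((n as f64 / pair_count as f64) * 255.0) as u64`. Because the cast truncates, distinct `n` can land on the same column, which is why the leaves are sorted and deduplicated by `col` before benchmarking. A standalone sketch of that mapping:

```rust
fn main() {
    // Same index formula as the benchmark setup, for pair_count = 256.
    let pair_count = 256u64;
    let cols: Vec<u64> = (0..pair_count)
        .map(|n| ((n as f64 / pair_count as f64) * 255.0) as u64)
        .collect();

    // Every column stays within the subtree's 0..=255 range.
    assert!(cols.iter().all(|&c| c <= 255));

    // Truncation produces at least one collision: fewer unique columns than inputs,
    // hence the sort + dedup_by_key(|leaf| leaf.col) in the setup.
    let mut unique = cols.clone();
    unique.sort_unstable();
    unique.dedup();
    assert!(unique.len() < pair_count as usize);

    println!("{} unique columns out of {}", unique.len(), pair_count);
}
```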
fn smt_subtree_random(c: &mut Criterion) {
let mut seed = [0u8; 32];
let mut group = c.benchmark_group("subtree8-rand");
for pair_count in PAIR_COUNTS {
let bench_id = BenchmarkId::from_parameter(pair_count);
group.bench_with_input(bench_id, &pair_count, |b, &pair_count| {
b.iter_batched(
|| {
// Setup.
let entries: Vec<(RpoDigest, Word)> = (0..pair_count)
.map(|i| {
let leaf_index: u8 = generate_value(&mut seed);
let key = RpoDigest::new([
ONE,
ONE,
Felt::new(i),
Felt::new(leaf_index as u64),
]);
let value = generate_word(&mut seed);
(key, value)
})
.collect();
let mut leaves: Vec<_> = entries
.iter()
.map(|(key, value)| {
let leaf = SmtLeaf::new_single(*key, *value);
let col = NodeIndex::from(leaf.index()).value();
let hash = leaf.hash();
SubtreeLeaf { col, hash }
})
.collect();
leaves.sort();
leaves
},
|leaves| {
let (subtree, _) = build_subtree_for_bench(
hint::black_box(leaves),
hint::black_box(SMT_DEPTH),
hint::black_box(SMT_DEPTH),
);
assert!(!subtree.is_empty());
},
BatchSize::SmallInput,
);
});
}
}
criterion_group! {
name = smt_subtree_group;
config = Criterion::default()
.measurement_time(Duration::from_secs(40))
.sample_size(60)
.configure_from_args();
targets = smt_subtree_even, smt_subtree_random
}
criterion_main!(smt_subtree_group);
// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------
fn generate_value<T: Copy + Debug + Randomizable>(seed: &mut [u8; 32]) -> T {
mem::swap(seed, &mut prng_array(*seed));
let value: [T; 1] = rand_utils::prng_array(*seed);
value[0]
}
fn generate_word(seed: &mut [u8; 32]) -> Word {
mem::swap(seed, &mut prng_array(*seed));
let nums: [u64; 4] = prng_array(*seed);
[Felt::new(nums[0]), Felt::new(nums[1]), Felt::new(nums[2]), Felt::new(nums[3])]
}


@@ -0,0 +1,71 @@
use std::{fmt::Debug, hint, mem, time::Duration};
use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use miden_crypto::{hash::rpo::RpoDigest, merkle::Smt, Felt, Word, ONE};
use rand_utils::prng_array;
use winter_utils::Randomizable;
// 2^0, 2^4, 2^8, 2^12, 2^16, 2^20
const PAIR_COUNTS: [u64; 6] = [1, 16, 256, 4096, 65536, 1_048_576];
fn smt_with_entries(c: &mut Criterion) {
let mut seed = [0u8; 32];
let mut group = c.benchmark_group("smt-with-entries");
for pair_count in PAIR_COUNTS {
let bench_id = BenchmarkId::from_parameter(pair_count);
group.bench_with_input(bench_id, &pair_count, |b, &pair_count| {
b.iter_batched(
|| {
// Setup.
prepare_entries(pair_count, &mut seed)
},
|entries| {
// Benchmarked function.
Smt::with_entries(hint::black_box(entries)).unwrap();
},
BatchSize::SmallInput,
);
});
}
}
criterion_group! {
name = smt_with_entries_group;
config = Criterion::default()
//.measurement_time(Duration::from_secs(960))
.measurement_time(Duration::from_secs(60))
.sample_size(10)
.configure_from_args();
targets = smt_with_entries
}
criterion_main!(smt_with_entries_group);
// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------
fn prepare_entries(pair_count: u64, seed: &mut [u8; 32]) -> Vec<(RpoDigest, [Felt; 4])> {
let entries: Vec<(RpoDigest, Word)> = (0..pair_count)
.map(|i| {
let count = pair_count as f64;
let idx = ((i as f64 / count) * (count)) as u64;
let key = RpoDigest::new([generate_value(seed), ONE, Felt::new(i), Felt::new(idx)]);
let value = generate_word(seed);
(key, value)
})
.collect();
entries
}
fn generate_value<T: Copy + Debug + Randomizable>(seed: &mut [u8; 32]) -> T {
mem::swap(seed, &mut prng_array(*seed));
let value: [T; 1] = rand_utils::prng_array(*seed);
value[0]
}
fn generate_word(seed: &mut [u8; 32]) -> Word {
mem::swap(seed, &mut prng_array(*seed));
let nums: [u64; 4] = prng_array(*seed);
[Felt::new(nums[0]), Felt::new(nums[1]), Felt::new(nums[2]), Felt::new(nums[3])]
}


@@ -1,16 +1,21 @@
use core::mem::swap;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use miden_crypto::{merkle::SimpleSmt, Felt, Word};
use miden_crypto::{
merkle::{LeafIndex, SimpleSmt},
Felt, Word,
};
use rand_utils::prng_array;
use seq_macro::seq;
fn smt_rpo(c: &mut Criterion) {
// setup trees
let mut seed = [0u8; 32];
let mut trees = vec![];
let leaf = generate_word(&mut seed);
for depth in 14..=20 {
let leaves = ((1 << depth) - 1) as u64;
seq!(DEPTH in 14..=20 {
let leaves = ((1 << DEPTH) - 1) as u64;
for count in [1, leaves / 2, leaves] {
let entries: Vec<_> = (0..count)
.map(|i| {
@@ -18,50 +23,45 @@ fn smt_rpo(c: &mut Criterion) {
(i, word)
})
.collect();
let tree = SimpleSmt::with_leaves(depth, entries).unwrap();
trees.push((tree, count));
let mut tree = SimpleSmt::<DEPTH>::with_leaves(entries).unwrap();
// benchmark 1
let mut insert = c.benchmark_group("smt update_leaf".to_string());
{
let depth = DEPTH;
let key = count >> 2;
insert.bench_with_input(
format!("simple smt(depth:{depth},count:{count})"),
&(key, leaf),
|b, (key, leaf)| {
b.iter(|| {
tree.insert(black_box(LeafIndex::<DEPTH>::new(*key).unwrap()), black_box(*leaf));
});
},
);
}
insert.finish();
// benchmark 2
let mut path = c.benchmark_group("smt get_leaf_path".to_string());
{
let depth = DEPTH;
let key = count >> 2;
path.bench_with_input(
format!("simple smt(depth:{depth},count:{count})"),
&key,
|b, key| {
b.iter(|| {
tree.open(black_box(&LeafIndex::<DEPTH>::new(*key).unwrap()));
});
},
);
}
path.finish();
}
}
let leaf = generate_word(&mut seed);
// benchmarks
let mut insert = c.benchmark_group(format!("smt update_leaf"));
for (tree, count) in trees.iter_mut() {
let depth = tree.depth();
let key = *count >> 2;
insert.bench_with_input(
format!("simple smt(depth:{depth},count:{count})"),
&(key, leaf),
|b, (key, leaf)| {
b.iter(|| {
tree.update_leaf(black_box(*key), black_box(*leaf)).unwrap();
});
},
);
}
insert.finish();
let mut path = c.benchmark_group(format!("smt get_leaf_path"));
for (tree, count) in trees.iter_mut() {
let depth = tree.depth();
let key = *count >> 2;
path.bench_with_input(
format!("simple smt(depth:{depth},count:{count})"),
&key,
|b, key| {
b.iter(|| {
tree.get_leaf_path(black_box(*key)).unwrap();
});
},
);
}
path.finish();
});
}
criterion_group!(smt_group, smt_rpo);


@@ -1,7 +1,12 @@
use criterion::{black_box, criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use miden_crypto::merkle::{DefaultMerkleStore as MerkleStore, MerkleTree, NodeIndex, SimpleSmt};
use miden_crypto::Word;
use miden_crypto::{hash::rpo::RpoDigest, Felt};
use miden_crypto::{
hash::rpo::RpoDigest,
merkle::{
DefaultMerkleStore as MerkleStore, LeafIndex, MerkleTree, NodeIndex, SimpleSmt,
SMT_MAX_DEPTH,
},
Felt, Word,
};
use rand_utils::{rand_array, rand_value};
/// Since MerkleTree can only be created when a power-of-two number of elements is used, the sample
@@ -15,7 +20,7 @@ fn random_rpo_digest() -> RpoDigest {
/// Generates a random `Word`.
fn random_word() -> Word {
rand_array::<Felt, 4>().into()
rand_array::<Felt, 4>()
}
/// Generates an index at the specified depth in `0..range`.
@@ -28,26 +33,26 @@ fn random_index(range: u64, depth: u8) -> NodeIndex {
fn get_empty_leaf_simplesmt(c: &mut Criterion) {
let mut group = c.benchmark_group("get_empty_leaf_simplesmt");
let depth = SimpleSmt::MAX_DEPTH;
const DEPTH: u8 = SMT_MAX_DEPTH;
let size = u64::MAX;
// both SMT and the store are pre-populated with empty hashes, accessing these values is what is
// being benchmarked here, so no values are inserted into the backends
let smt = SimpleSmt::new(depth).unwrap();
let smt = SimpleSmt::<DEPTH>::new().unwrap();
let store = MerkleStore::from(&smt);
let root = smt.root();
group.bench_function(BenchmarkId::new("SimpleSmt", depth), |b| {
group.bench_function(BenchmarkId::new("SimpleSmt", DEPTH), |b| {
b.iter_batched(
|| random_index(size, depth),
|| random_index(size, DEPTH),
|index| black_box(smt.get_node(index)),
BatchSize::SmallInput,
)
});
group.bench_function(BenchmarkId::new("MerkleStore", depth), |b| {
group.bench_function(BenchmarkId::new("MerkleStore", DEPTH), |b| {
b.iter_batched(
|| random_index(size, depth),
|| random_index(size, DEPTH),
|index| black_box(store.get_node(root, index)),
BatchSize::SmallInput,
)
@@ -104,15 +109,14 @@ fn get_leaf_simplesmt(c: &mut Criterion) {
.enumerate()
.map(|(c, v)| (c.try_into().unwrap(), v.into()))
.collect::<Vec<(u64, Word)>>();
let smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, smt_leaves.clone()).unwrap();
let smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(smt_leaves.clone()).unwrap();
let store = MerkleStore::from(&smt);
let depth = smt.depth();
let root = smt.root();
let size_u64 = size as u64;
group.bench_function(BenchmarkId::new("SimpleSmt", size), |b| {
b.iter_batched(
|| random_index(size_u64, depth),
|| random_index(size_u64, SMT_MAX_DEPTH),
|index| black_box(smt.get_node(index)),
BatchSize::SmallInput,
)
@@ -120,7 +124,7 @@ fn get_leaf_simplesmt(c: &mut Criterion) {
group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
b.iter_batched(
|| random_index(size_u64, depth),
|| random_index(size_u64, SMT_MAX_DEPTH),
|index| black_box(store.get_node(root, index)),
BatchSize::SmallInput,
)
@@ -132,18 +136,18 @@ fn get_leaf_simplesmt(c: &mut Criterion) {
fn get_node_of_empty_simplesmt(c: &mut Criterion) {
let mut group = c.benchmark_group("get_node_of_empty_simplesmt");
let depth = SimpleSmt::MAX_DEPTH;
const DEPTH: u8 = SMT_MAX_DEPTH;
// both SMT and the store are pre-populated with the empty hashes, accessing the internal nodes
// of these values is what is being benchmarked here, so no values are inserted into the
// backends.
let smt = SimpleSmt::new(depth).unwrap();
let smt = SimpleSmt::<DEPTH>::new().unwrap();
let store = MerkleStore::from(&smt);
let root = smt.root();
let half_depth = depth / 2;
let half_depth = DEPTH / 2;
let half_size = 2_u64.pow(half_depth as u32);
group.bench_function(BenchmarkId::new("SimpleSmt", depth), |b| {
group.bench_function(BenchmarkId::new("SimpleSmt", DEPTH), |b| {
b.iter_batched(
|| random_index(half_size, half_depth),
|index| black_box(smt.get_node(index)),
@@ -151,7 +155,7 @@ fn get_node_of_empty_simplesmt(c: &mut Criterion) {
)
});
group.bench_function(BenchmarkId::new("MerkleStore", depth), |b| {
group.bench_function(BenchmarkId::new("MerkleStore", DEPTH), |b| {
b.iter_batched(
|| random_index(half_size, half_depth),
|index| black_box(store.get_node(root, index)),
@@ -212,10 +216,10 @@ fn get_node_simplesmt(c: &mut Criterion) {
.enumerate()
.map(|(c, v)| (c.try_into().unwrap(), v.into()))
.collect::<Vec<(u64, Word)>>();
let smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, smt_leaves.clone()).unwrap();
let smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(smt_leaves.clone()).unwrap();
let store = MerkleStore::from(&smt);
let root = smt.root();
let half_depth = smt.depth() / 2;
let half_depth = SMT_MAX_DEPTH / 2;
let half_size = 2_u64.pow(half_depth as u32);
group.bench_function(BenchmarkId::new("SimpleSmt", size), |b| {
@@ -286,23 +290,24 @@ fn get_leaf_path_simplesmt(c: &mut Criterion) {
.enumerate()
.map(|(c, v)| (c.try_into().unwrap(), v.into()))
.collect::<Vec<(u64, Word)>>();
let smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, smt_leaves.clone()).unwrap();
let smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(smt_leaves.clone()).unwrap();
let store = MerkleStore::from(&smt);
let depth = smt.depth();
let root = smt.root();
let size_u64 = size as u64;
group.bench_function(BenchmarkId::new("SimpleSmt", size), |b| {
b.iter_batched(
|| random_index(size_u64, depth),
|index| black_box(smt.get_path(index)),
|| random_index(size_u64, SMT_MAX_DEPTH),
|index| {
black_box(smt.open(&LeafIndex::<SMT_MAX_DEPTH>::new(index.value()).unwrap()))
},
BatchSize::SmallInput,
)
});
group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
b.iter_batched(
|| random_index(size_u64, depth),
|| random_index(size_u64, SMT_MAX_DEPTH),
|index| black_box(store.get_path(root, index)),
BatchSize::SmallInput,
)
@@ -352,7 +357,7 @@ fn new(c: &mut Criterion) {
.map(|(c, v)| (c.try_into().unwrap(), v.into()))
.collect::<Vec<(u64, Word)>>()
},
|l| black_box(SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, l)),
|l| black_box(SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(l)),
BatchSize::SmallInput,
)
});
@@ -367,7 +372,7 @@ fn new(c: &mut Criterion) {
.collect::<Vec<(u64, Word)>>()
},
|l| {
let smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, l).unwrap();
let smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(l).unwrap();
black_box(MerkleStore::from(&smt));
},
BatchSize::SmallInput,
@@ -433,16 +438,17 @@ fn update_leaf_simplesmt(c: &mut Criterion) {
.enumerate()
.map(|(c, v)| (c.try_into().unwrap(), v.into()))
.collect::<Vec<(u64, Word)>>();
let mut smt = SimpleSmt::with_leaves(SimpleSmt::MAX_DEPTH, smt_leaves.clone()).unwrap();
let mut smt = SimpleSmt::<SMT_MAX_DEPTH>::with_leaves(smt_leaves.clone()).unwrap();
let mut store = MerkleStore::from(&smt);
let depth = smt.depth();
let root = smt.root();
let size_u64 = size as u64;
group.bench_function(BenchmarkId::new("SimpleSMT", size), |b| {
b.iter_batched(
|| (rand_value::<u64>() % size_u64, random_word()),
|(index, value)| black_box(smt.update_leaf(index, value)),
|(index, value)| {
black_box(smt.insert(LeafIndex::<SMT_MAX_DEPTH>::new(index).unwrap(), value))
},
BatchSize::SmallInput,
)
});
@@ -450,7 +456,7 @@ fn update_leaf_simplesmt(c: &mut Criterion) {
let mut store_root = root;
group.bench_function(BenchmarkId::new("MerkleStore", size), |b| {
b.iter_batched(
|| (random_index(size_u64, depth), random_word()),
|| (random_index(size_u64, SMT_MAX_DEPTH), random_word()),
|(index, value)| {
// The MerkleTree automatically updates its internal root, the Store maintains
// the old root and adds the new one. Here we update the root to have a fair


@@ -1,40 +1,9 @@
fn main() {
#[cfg(feature = "std")]
compile_rpo_falcon();
#[cfg(all(target_feature = "sve", feature = "sve"))]
#[cfg(target_feature = "sve")]
compile_arch_arm64_sve();
}
#[cfg(feature = "std")]
fn compile_rpo_falcon() {
use std::path::PathBuf;
const RPO_FALCON_PATH: &str = "src/dsa/rpo_falcon512/falcon_c";
println!("cargo:rerun-if-changed={RPO_FALCON_PATH}/falcon.h");
println!("cargo:rerun-if-changed={RPO_FALCON_PATH}/falcon.c");
println!("cargo:rerun-if-changed={RPO_FALCON_PATH}/rpo.h");
println!("cargo:rerun-if-changed={RPO_FALCON_PATH}/rpo.c");
let target_dir: PathBuf = ["PQClean", "crypto_sign", "falcon-512", "clean"].iter().collect();
let common_dir: PathBuf = ["PQClean", "common"].iter().collect();
let scheme_files = glob::glob(target_dir.join("*.c").to_str().unwrap()).unwrap();
let common_files = glob::glob(common_dir.join("*.c").to_str().unwrap()).unwrap();
cc::Build::new()
.include(&common_dir)
.include(target_dir)
.files(scheme_files.into_iter().map(|p| p.unwrap().to_string_lossy().into_owned()))
.files(common_files.into_iter().map(|p| p.unwrap().to_string_lossy().into_owned()))
.file(format!("{RPO_FALCON_PATH}/falcon.c"))
.file(format!("{RPO_FALCON_PATH}/rpo.c"))
.flag("-O3")
.compile("rpo_falcon512");
}
#[cfg(all(target_feature = "sve", feature = "sve"))]
#[cfg(target_feature = "sve")]
fn compile_arch_arm64_sve() {
const RPO_SVE_PATH: &str = "arch/arm64-sve/rpo";

rust-toolchain.toml Normal file

@@ -0,0 +1,5 @@
[toolchain]
channel = "1.82"
components = ["rustfmt", "rust-src", "clippy"]
targets = ["wasm32-unknown-unknown"]
profile = "minimal"


@@ -2,20 +2,22 @@ edition = "2021"
array_width = 80
attr_fn_like_width = 80
chain_width = 80
#condense_wildcard_suffixes = true
#enum_discrim_align_threshold = 40
comment_width = 100
condense_wildcard_suffixes = true
fn_call_width = 80
#fn_single_line = true
#format_code_in_doc_comments = true
#format_macro_matchers = true
#format_strings = true
#group_imports = "StdExternalCrate"
#hex_literal_case = "Lower"
#imports_granularity = "Crate"
format_code_in_doc_comments = true
format_macro_matchers = true
group_imports = "StdExternalCrate"
hex_literal_case = "Lower"
imports_granularity = "Crate"
match_block_trailing_comma = true
newline_style = "Unix"
#normalize_doc_attributes = true
#reorder_impl_items = true
reorder_imports = true
reorder_modules = true
single_line_if_else_max_width = 60
single_line_let_else_max_width = 60
struct_lit_width = 40
struct_variant_width = 40
use_field_init_shorthand = true
use_try_shorthand = true
wrap_comments = true

scripts/check-changelog.sh Executable file

@@ -0,0 +1,21 @@
#!/bin/bash
set -uo pipefail
CHANGELOG_FILE="${1:-CHANGELOG.md}"
if [ "${NO_CHANGELOG_LABEL}" = "true" ]; then
# 'no changelog' set, so finish successfully
echo "\"no changelog\" label has been set"
exit 0
else
# a changelog check is required
# fail if the diff is empty
if git diff --exit-code "origin/${BASE_REF}" -- "${CHANGELOG_FILE}"; then
>&2 echo "Changes should come with an entry in the \"CHANGELOG.md\" file. This behavior
can be overridden by using the \"no changelog\" label, which is used for changes
that are trivial / explicitly stated not to require a changelog entry."
exit 1
fi
echo "The \"CHANGELOG.md\" file has been updated."
fi

scripts/check-rust-version.sh Executable file

@@ -0,0 +1,15 @@
#!/bin/bash
# Get rust-toolchain.toml file channel
TOOLCHAIN_VERSION=$(grep 'channel' rust-toolchain.toml | sed -E 's/.*"(.*)".*/\1/')
# Get workspace Cargo.toml file rust-version
CARGO_VERSION=$(grep 'rust-version' Cargo.toml | sed -E 's/.*"(.*)".*/\1/')
# Check version match
if [ "$CARGO_VERSION" != "$TOOLCHAIN_VERSION" ]; then
echo "Mismatch in Cargo.toml: Expected $TOOLCHAIN_VERSION, found $CARGO_VERSION"
exit 1
fi
echo "Rust versions match ✅"


@@ -1,3 +1,5 @@
//! Digital signature schemes supported by default in the Miden VM.
pub mod rpo_falcon512;
pub mod rpo_stark;


@@ -1,55 +0,0 @@
use super::{LOG_N, MODULUS, PK_LEN};
use core::fmt;
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FalconError {
KeyGenerationFailed,
PubKeyDecodingExtraData,
PubKeyDecodingInvalidCoefficient(u32),
PubKeyDecodingInvalidLength(usize),
PubKeyDecodingInvalidTag(u8),
SigDecodingTooBigHighBits(u32),
SigDecodingInvalidRemainder,
SigDecodingNonZeroUnusedBitsLastByte,
SigDecodingMinusZero,
SigDecodingIncorrectEncodingAlgorithm,
SigDecodingNotSupportedDegree(u8),
SigGenerationFailed,
}
impl fmt::Display for FalconError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use FalconError::*;
match self {
KeyGenerationFailed => write!(f, "Failed to generate a private-public key pair"),
PubKeyDecodingExtraData => {
write!(f, "Failed to decode public key: input not fully consumed")
}
PubKeyDecodingInvalidCoefficient(val) => {
write!(f, "Failed to decode public key: coefficient {val} is greater than or equal to the field modulus {MODULUS}")
}
PubKeyDecodingInvalidLength(len) => {
write!(f, "Failed to decode public key: expected {PK_LEN} bytes but received {len}")
}
PubKeyDecodingInvalidTag(byte) => {
write!(f, "Failed to decode public key: expected the first byte to be {LOG_N} but was {byte}")
}
SigDecodingTooBigHighBits(m) => {
write!(f, "Failed to decode signature: high bits {m} exceed 2048")
}
SigDecodingInvalidRemainder => {
write!(f, "Failed to decode signature: incorrect remaining data")
}
SigDecodingNonZeroUnusedBitsLastByte => {
write!(f, "Failed to decode signature: Non-zero unused bits in the last byte")
}
SigDecodingMinusZero => write!(f, "Failed to decode signature: -0 is forbidden"),
SigDecodingIncorrectEncodingAlgorithm => write!(f, "Failed to decode signature: not supported encoding algorithm"),
SigDecodingNotSupportedDegree(log_n) => write!(f, "Failed to decode signature: only supported irreducible polynomial degree is 512, 2^{log_n} was provided"),
SigGenerationFailed => write!(f, "Failed to generate a signature"),
}
}
}
#[cfg(feature = "std")]
impl std::error::Error for FalconError {}


@@ -1,402 +0,0 @@
/*
* Wrapper for implementing the PQClean API.
*/
#include <string.h>
#include "randombytes.h"
#include "falcon.h"
#include "inner.h"
#include "rpo.h"
#define NONCELEN 40
/*
* Encoding formats (nnnn = log of degree, 9 for Falcon-512, 10 for Falcon-1024)
*
* private key:
* header byte: 0101nnnn
* private f (6 or 5 bits by element, depending on degree)
* private g (6 or 5 bits by element, depending on degree)
* private F (8 bits by element)
*
* public key:
* header byte: 0000nnnn
* public h (14 bits by element)
*
* signature:
* header byte: 0011nnnn
* nonce 40 bytes
* value (12 bits by element)
*
* message + signature:
* signature length (2 bytes, big-endian)
* nonce 40 bytes
* message
* header byte: 0010nnnn
* value (12 bits by element)
* (signature length is 1+len(value), not counting the nonce)
*/
/* see falcon.h */
int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(
uint8_t *pk,
uint8_t *sk,
unsigned char *seed
) {
union
{
uint8_t b[FALCON_KEYGEN_TEMP_9];
uint64_t dummy_u64;
fpr dummy_fpr;
} tmp;
int8_t f[512], g[512], F[512];
uint16_t h[512];
inner_shake256_context rng;
size_t u, v;
/*
* Generate key pair.
*/
inner_shake256_init(&rng);
inner_shake256_inject(&rng, seed, sizeof seed);
inner_shake256_flip(&rng);
PQCLEAN_FALCON512_CLEAN_keygen(&rng, f, g, F, NULL, h, 9, tmp.b);
inner_shake256_ctx_release(&rng);
/*
* Encode private key.
*/
sk[0] = 0x50 + 9;
u = 1;
v = PQCLEAN_FALCON512_CLEAN_trim_i8_encode(
sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u,
f, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9]);
if (v == 0)
{
return -1;
}
u += v;
v = PQCLEAN_FALCON512_CLEAN_trim_i8_encode(
sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u,
g, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9]);
if (v == 0)
{
return -1;
}
u += v;
v = PQCLEAN_FALCON512_CLEAN_trim_i8_encode(
sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u,
F, 9, PQCLEAN_FALCON512_CLEAN_max_FG_bits[9]);
if (v == 0)
{
return -1;
}
u += v;
if (u != PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES)
{
return -1;
}
/*
* Encode public key.
*/
pk[0] = 0x00 + 9;
v = PQCLEAN_FALCON512_CLEAN_modq_encode(
pk + 1, PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1,
h, 9);
if (v != PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1)
{
return -1;
}
return 0;
}
int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(
uint8_t *pk,
uint8_t *sk
) {
unsigned char seed[48];
/*
* Generate a random seed.
*/
randombytes(seed, sizeof seed);
return PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(pk, sk, seed);
}
/*
* Compute the signature. nonce[] receives the nonce and must have length
* NONCELEN bytes. sigbuf[] receives the signature value (without nonce
* or header byte), with *sigbuflen providing the maximum value length and
* receiving the actual value length.
*
* If a signature could be computed but not encoded because it would
* exceed the output buffer size, then a new signature is computed. If
* the provided buffer size is too low, this could loop indefinitely, so
* the caller must provide a size that can accommodate signatures with a
* large enough probability.
*
* Return value: 0 on success, -1 on error.
*/
static int do_sign(
uint8_t *nonce,
uint8_t *sigbuf,
size_t *sigbuflen,
const uint8_t *m,
size_t mlen,
const uint8_t *sk
) {
union
{
uint8_t b[72 * 512];
uint64_t dummy_u64;
fpr dummy_fpr;
} tmp;
int8_t f[512], g[512], F[512], G[512];
struct
{
int16_t sig[512];
uint16_t hm[512];
} r;
unsigned char seed[48];
inner_shake256_context sc;
rpo128_context rc;
size_t u, v;
/*
* Decode the private key.
*/
if (sk[0] != 0x50 + 9)
{
return -1;
}
u = 1;
v = PQCLEAN_FALCON512_CLEAN_trim_i8_decode(
f, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9],
sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u);
if (v == 0)
{
return -1;
}
u += v;
v = PQCLEAN_FALCON512_CLEAN_trim_i8_decode(
g, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9],
sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u);
if (v == 0)
{
return -1;
}
u += v;
v = PQCLEAN_FALCON512_CLEAN_trim_i8_decode(
F, 9, PQCLEAN_FALCON512_CLEAN_max_FG_bits[9],
sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u);
if (v == 0)
{
return -1;
}
u += v;
if (u != PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES)
{
return -1;
}
if (!PQCLEAN_FALCON512_CLEAN_complete_private(G, f, g, F, 9, tmp.b))
{
return -1;
}
/*
* Create a random nonce (40 bytes).
*/
randombytes(nonce, NONCELEN);
/* ==== Start: Deviation from the reference implementation ================================= */
// Transform the nonce into 8 chunks each of size 5 bytes. We do this in order to be sure that
// the conversion to field elements succeeds
uint8_t buffer[64];
memset(buffer, 0, 64);
for (size_t i = 0; i < 8; i++)
{
buffer[8 * i] = nonce[5 * i];
buffer[8 * i + 1] = nonce[5 * i + 1];
buffer[8 * i + 2] = nonce[5 * i + 2];
buffer[8 * i + 3] = nonce[5 * i + 3];
buffer[8 * i + 4] = nonce[5 * i + 4];
}
/*
* Hash message nonce + message into a vector.
*/
rpo128_init(&rc);
rpo128_absorb(&rc, buffer, NONCELEN + 24);
rpo128_absorb(&rc, m, mlen);
rpo128_finalize(&rc);
PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(&rc, r.hm, 9);
rpo128_release(&rc);
/* ==== End: Deviation from the reference implementation =================================== */
/*
* Initialize a RNG.
*/
randombytes(seed, sizeof seed);
inner_shake256_init(&sc);
inner_shake256_inject(&sc, seed, sizeof seed);
inner_shake256_flip(&sc);
/*
* Compute and return the signature. This loops until a signature
* value is found that fits in the provided buffer.
*/
for (;;)
{
PQCLEAN_FALCON512_CLEAN_sign_dyn(r.sig, &sc, f, g, F, G, r.hm, 9, tmp.b);
v = PQCLEAN_FALCON512_CLEAN_comp_encode(sigbuf, *sigbuflen, r.sig, 9);
if (v != 0)
{
inner_shake256_ctx_release(&sc);
*sigbuflen = v;
return 0;
}
}
}
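The nonce handling in `do_sign` above (and mirrored in `do_verify` below) turns the 40-byte nonce into eight 8-byte words, each carrying only 5 bytes of data, so every word is below 2^40 and is guaranteed to fit in a Goldilocks field element (modulus 2^64 - 2^32 + 1). A standalone Rust sketch of the same packing, checking that invariant:

```rust
fn main() {
    const NONCELEN: usize = 40;
    // Arbitrary sample nonce; randombytes() fills this in the C code.
    let nonce: [u8; NONCELEN] = std::array::from_fn(|i| (i as u8).wrapping_mul(37));

    // Mirror of the C loop: 8 chunks of 5 bytes, each placed at an 8-byte stride.
    let mut buffer = [0u8; 64];
    for i in 0..8 {
        buffer[8 * i..8 * i + 5].copy_from_slice(&nonce[5 * i..5 * i + 5]);
    }

    // Each 8-byte word has its top 3 bytes zeroed, so its value is < 2^40,
    // comfortably below the Goldilocks modulus 2^64 - 2^32 + 1.
    for i in 0..8 {
        let word = u64::from_le_bytes(buffer[8 * i..8 * i + 8].try_into().unwrap());
        assert!(word < 1u64 << 40);
    }
    println!("all 8 words fit in a field element");
}
```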
/*
* Verify a signature. The nonce has size NONCELEN bytes. sigbuf[]
* (of size sigbuflen) contains the signature value, not including the
* header byte or nonce. Return value is 0 on success, -1 on error.
*/
static int do_verify(
const uint8_t *nonce,
const uint8_t *sigbuf,
size_t sigbuflen,
const uint8_t *m,
size_t mlen,
const uint8_t *pk
) {
union
{
uint8_t b[2 * 512];
uint64_t dummy_u64;
fpr dummy_fpr;
} tmp;
uint16_t h[512], hm[512];
int16_t sig[512];
rpo128_context rc;
/*
* Decode public key.
*/
if (pk[0] != 0x00 + 9)
{
return -1;
}
if (PQCLEAN_FALCON512_CLEAN_modq_decode(h, 9,
pk + 1, PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1)
!= PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1)
{
return -1;
}
PQCLEAN_FALCON512_CLEAN_to_ntt_monty(h, 9);
/*
* Decode signature.
*/
if (sigbuflen == 0)
{
return -1;
}
if (PQCLEAN_FALCON512_CLEAN_comp_decode(sig, 9, sigbuf, sigbuflen) != sigbuflen)
{
return -1;
}
/* ==== Start: Deviation from the reference implementation ================================= */
/*
* Hash nonce + message into a vector.
*/
// Transform the nonce into 8 chunks each of size 5 bytes. We do this in order to be sure that
// the conversion to field elements succeeds
uint8_t buffer[64];
memset(buffer, 0, 64);
for (size_t i = 0; i < 8; i++)
{
buffer[8 * i] = nonce[5 * i];
buffer[8 * i + 1] = nonce[5 * i + 1];
buffer[8 * i + 2] = nonce[5 * i + 2];
buffer[8 * i + 3] = nonce[5 * i + 3];
buffer[8 * i + 4] = nonce[5 * i + 4];
}
rpo128_init(&rc);
rpo128_absorb(&rc, buffer, NONCELEN + 24);
rpo128_absorb(&rc, m, mlen);
rpo128_finalize(&rc);
PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(&rc, hm, 9);
rpo128_release(&rc);
/* === End: Deviation from the reference implementation ==================================== */
/*
* Verify signature.
*/
if (!PQCLEAN_FALCON512_CLEAN_verify_raw(hm, sig, h, 9, tmp.b))
{
return -1;
}
return 0;
}
/* see falcon.h */
int PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
uint8_t *sig,
size_t *siglen,
const uint8_t *m,
size_t mlen,
const uint8_t *sk
) {
/*
* The PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES constant is used for
* the signed message object (as produced by crypto_sign())
* and includes a two-byte length value, so we take care here
* to only generate signatures that are two bytes shorter than
* the maximum. This is done to ensure that crypto_sign()
* and crypto_sign_signature() produce the exact same signature
* value, if used on the same message, with the same private key,
* and using the same output from randombytes() (this is for
* reproducibility of tests).
*/
size_t vlen;
vlen = PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES - NONCELEN - 3;
if (do_sign(sig + 1, sig + 1 + NONCELEN, &vlen, m, mlen, sk) < 0)
{
return -1;
}
sig[0] = 0x30 + 9;
*siglen = 1 + NONCELEN + vlen;
return 0;
}
/* see falcon.h */
int PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
const uint8_t *sig,
size_t siglen,
const uint8_t *m,
size_t mlen,
const uint8_t *pk
) {
if (siglen < 1 + NONCELEN)
{
return -1;
}
if (sig[0] != 0x30 + 9)
{
return -1;
}
return do_verify(sig + 1, sig + 1 + NONCELEN, siglen - 1 - NONCELEN, m, mlen, pk);
}


@@ -1,66 +0,0 @@
#include <stddef.h>
#include <stdint.h>
#define PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES 1281
#define PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES 897
#define PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES 666
/*
* Generate a new key pair. Public key goes into pk[], private key in sk[].
* Key sizes are exact (in bytes):
* public (pk): PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES
* private (sk): PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES
*
* Return value: 0 on success, -1 on error.
*
* Note: This implementation follows the reference implementation in PQClean
* https://github.com/PQClean/PQClean/tree/master/crypto_sign/falcon-512
* verbatim except for the sections that are marked otherwise.
*/
int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(
uint8_t *pk, uint8_t *sk);
/*
* Generate a new key pair from seed. Public key goes into pk[], private key in sk[].
* Key sizes are exact (in bytes):
* public (pk): PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES
* private (sk): PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES
*
* Return value: 0 on success, -1 on error.
*/
int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(
uint8_t *pk, uint8_t *sk, unsigned char *seed);
/*
* Compute a signature on a provided message (m, mlen), with a given
* private key (sk). Signature is written in sig[], with length written
* into *siglen. Signature length is variable; maximum signature length
* (in bytes) is PQCLEAN_FALCON512_CLEAN_CRYPTO_BYTES.
*
* sig[], m[] and sk[] may overlap each other arbitrarily.
*
* Return value: 0 on success, -1 on error.
*
* Note: This implementation follows the reference implementation in PQClean
* https://github.com/PQClean/PQClean/tree/master/crypto_sign/falcon-512
* verbatim except for the sections that are marked otherwise.
*/
int PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
uint8_t *sig, size_t *siglen,
const uint8_t *m, size_t mlen, const uint8_t *sk);
/*
* Verify a signature (sig, siglen) on a message (m, mlen) with a given
* public key (pk).
*
* sig[], m[] and pk[] may overlap each other arbitrarily.
*
* Return value: 0 on success, -1 on error.
*
* Note: This implementation follows the reference implementation in PQClean
* https://github.com/PQClean/PQClean/tree/master/crypto_sign/falcon-512
* verbatim except for the sections that are marked otherwise.
*/
int PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
const uint8_t *sig, size_t siglen,
const uint8_t *m, size_t mlen, const uint8_t *pk);


@@ -1,582 +0,0 @@
/*
* RPO implementation.
*/
#include <stdint.h>
#include <string.h>
#include <stdlib.h>
/* ================================================================================================
* Modular Arithmetic
*/
#define P 0xFFFFFFFF00000001
#define M 12289
// From https://github.com/ncw/iprime/blob/master/mod_math_noasm.go
static uint64_t add_mod_p(uint64_t a, uint64_t b)
{
a = P - a;
uint64_t res = b - a;
if (b < a)
res += P;
return res;
}
static uint64_t sub_mod_p(uint64_t a, uint64_t b)
{
uint64_t r = a - b;
if (a < b)
r += P;
return r;
}
static uint64_t reduce_mod_p(uint64_t b, uint64_t a)
{
uint32_t d = b >> 32,
c = b;
if (a >= P)
a -= P;
a = sub_mod_p(a, c);
a = sub_mod_p(a, d);
a = add_mod_p(a, ((uint64_t)c) << 32);
return a;
}
static uint64_t mult_mod_p(uint64_t x, uint64_t y)
{
uint32_t a = x,
b = x >> 32,
c = y,
d = y >> 32;
/* first synthesize the product using 32*32 -> 64 bit multiplies */
x = b * (uint64_t)c; /* b*c */
y = a * (uint64_t)d; /* a*d */
uint64_t e = a * (uint64_t)c, /* a*c */
f = b * (uint64_t)d, /* b*d */
t;
x += y; /* b*c + a*d */
/* carry? */
if (x < y)
f += 1LL << 32; /* carry into upper 32 bits - can't overflow */
t = x << 32;
e += t; /* a*c + LSW(b*c + a*d) */
/* carry? */
if (e < t)
f += 1; /* carry into upper 64 bits - can't overflow*/
t = x >> 32;
f += t; /* b*d + MSW(b*c + a*d) */
/* can't overflow */
/* now reduce: (b*d + MSW(b*c + a*d), a*c + LSW(b*c + a*d)) */
return reduce_mod_p(f, e);
}
/* ================================================================================================
* RPO128 Permutation
*/
static const uint64_t STATE_WIDTH = 12;
static const uint64_t NUM_ROUNDS = 7;
/*
* MDS matrix
*/
static const uint64_t MDS[12][12] = {
{ 7, 23, 8, 26, 13, 10, 9, 7, 6, 22, 21, 8 },
{ 8, 7, 23, 8, 26, 13, 10, 9, 7, 6, 22, 21 },
{ 21, 8, 7, 23, 8, 26, 13, 10, 9, 7, 6, 22 },
{ 22, 21, 8, 7, 23, 8, 26, 13, 10, 9, 7, 6 },
{ 6, 22, 21, 8, 7, 23, 8, 26, 13, 10, 9, 7 },
{ 7, 6, 22, 21, 8, 7, 23, 8, 26, 13, 10, 9 },
{ 9, 7, 6, 22, 21, 8, 7, 23, 8, 26, 13, 10 },
{ 10, 9, 7, 6, 22, 21, 8, 7, 23, 8, 26, 13 },
{ 13, 10, 9, 7, 6, 22, 21, 8, 7, 23, 8, 26 },
{ 26, 13, 10, 9, 7, 6, 22, 21, 8, 7, 23, 8 },
{ 8, 26, 13, 10, 9, 7, 6, 22, 21, 8, 7, 23 },
{ 23, 8, 26, 13, 10, 9, 7, 6, 22, 21, 8, 7 },
};
/*
* Round constants.
*/
static const uint64_t ARK1[7][12] = {
{
5789762306288267392ULL,
6522564764413701783ULL,
17809893479458208203ULL,
107145243989736508ULL,
6388978042437517382ULL,
15844067734406016715ULL,
9975000513555218239ULL,
3344984123768313364ULL,
9959189626657347191ULL,
12960773468763563665ULL,
9602914297752488475ULL,
16657542370200465908ULL,
},
{
12987190162843096997ULL,
653957632802705281ULL,
4441654670647621225ULL,
4038207883745915761ULL,
5613464648874830118ULL,
13222989726778338773ULL,
3037761201230264149ULL,
16683759727265180203ULL,
8337364536491240715ULL,
3227397518293416448ULL,
8110510111539674682ULL,
2872078294163232137ULL,
},
{
18072785500942327487ULL,
6200974112677013481ULL,
17682092219085884187ULL,
10599526828986756440ULL,
975003873302957338ULL,
8264241093196931281ULL,
10065763900435475170ULL,
2181131744534710197ULL,
6317303992309418647ULL,
1401440938888741532ULL,
8884468225181997494ULL,
13066900325715521532ULL,
},
{
5674685213610121970ULL,
5759084860419474071ULL,
13943282657648897737ULL,
1352748651966375394ULL,
17110913224029905221ULL,
1003883795902368422ULL,
4141870621881018291ULL,
8121410972417424656ULL,
14300518605864919529ULL,
13712227150607670181ULL,
17021852944633065291ULL,
6252096473787587650ULL,
},
{
4887609836208846458ULL,
3027115137917284492ULL,
9595098600469470675ULL,
10528569829048484079ULL,
7864689113198939815ULL,
17533723827845969040ULL,
5781638039037710951ULL,
17024078752430719006ULL,
109659393484013511ULL,
7158933660534805869ULL,
2955076958026921730ULL,
7433723648458773977ULL,
},
{
16308865189192447297ULL,
11977192855656444890ULL,
12532242556065780287ULL,
14594890931430968898ULL,
7291784239689209784ULL,
5514718540551361949ULL,
10025733853830934803ULL,
7293794580341021693ULL,
6728552937464861756ULL,
6332385040983343262ULL,
13277683694236792804ULL,
2600778905124452676ULL,
},
{
7123075680859040534ULL,
1034205548717903090ULL,
7717824418247931797ULL,
3019070937878604058ULL,
11403792746066867460ULL,
10280580802233112374ULL,
337153209462421218ULL,
13333398568519923717ULL,
3596153696935337464ULL,
8104208463525993784ULL,
14345062289456085693ULL,
17036731477169661256ULL,
}};
static const uint64_t ARK2[7][12] = {
{
6077062762357204287ULL,
15277620170502011191ULL,
5358738125714196705ULL,
14233283787297595718ULL,
13792579614346651365ULL,
11614812331536767105ULL,
14871063686742261166ULL,
10148237148793043499ULL,
4457428952329675767ULL,
15590786458219172475ULL,
10063319113072092615ULL,
14200078843431360086ULL,
},
{
6202948458916099932ULL,
17690140365333231091ULL,
3595001575307484651ULL,
373995945117666487ULL,
1235734395091296013ULL,
14172757457833931602ULL,
707573103686350224ULL,
15453217512188187135ULL,
219777875004506018ULL,
17876696346199469008ULL,
17731621626449383378ULL,
2897136237748376248ULL,
},
{
8023374565629191455ULL,
15013690343205953430ULL,
4485500052507912973ULL,
12489737547229155153ULL,
9500452585969030576ULL,
2054001340201038870ULL,
12420704059284934186ULL,
355990932618543755ULL,
9071225051243523860ULL,
12766199826003448536ULL,
9045979173463556963ULL,
12934431667190679898ULL,
},
{
18389244934624494276ULL,
16731736864863925227ULL,
4440209734760478192ULL,
17208448209698888938ULL,
8739495587021565984ULL,
17000774922218161967ULL,
13533282547195532087ULL,
525402848358706231ULL,
16987541523062161972ULL,
5466806524462797102ULL,
14512769585918244983ULL,
10973956031244051118ULL,
},
{
6982293561042362913ULL,
14065426295947720331ULL,
16451845770444974180ULL,
7139138592091306727ULL,
9012006439959783127ULL,
14619614108529063361ULL,
1394813199588124371ULL,
4635111139507788575ULL,
16217473952264203365ULL,
10782018226466330683ULL,
6844229992533662050ULL,
7446486531695178711ULL,
},
{
3736792340494631448ULL,
577852220195055341ULL,
6689998335515779805ULL,
13886063479078013492ULL,
14358505101923202168ULL,
7744142531772274164ULL,
16135070735728404443ULL,
12290902521256031137ULL,
12059913662657709804ULL,
16456018495793751911ULL,
4571485474751953524ULL,
17200392109565783176ULL,
},
{
17130398059294018733ULL,
519782857322261988ULL,
9625384390925085478ULL,
1664893052631119222ULL,
7629576092524553570ULL,
3485239601103661425ULL,
9755891797164033838ULL,
15218148195153269027ULL,
16460604813734957368ULL,
9643968136937729763ULL,
3611348709641382851ULL,
18256379591337759196ULL,
},
};
static void apply_sbox(uint64_t *const state)
{
for (uint64_t i = 0; i < STATE_WIDTH; i++)
{
uint64_t t2 = mult_mod_p(*(state + i), *(state + i));
uint64_t t4 = mult_mod_p(t2, t2);
*(state + i) = mult_mod_p(*(state + i), mult_mod_p(t2, t4));
}
}
static void apply_mds(uint64_t *state)
{
uint64_t res[STATE_WIDTH];
for (uint64_t i = 0; i < STATE_WIDTH; i++)
{
res[i] = 0;
}
for (uint64_t i = 0; i < STATE_WIDTH; i++)
{
for (uint64_t j = 0; j < STATE_WIDTH; j++)
{
res[i] = add_mod_p(res[i], mult_mod_p(MDS[i][j], *(state + j)));
}
}
for (uint64_t i = 0; i < STATE_WIDTH; i++)
{
*(state + i) = res[i];
}
}
static void apply_constants(uint64_t *const state, const uint64_t *ark)
{
for (uint64_t i = 0; i < STATE_WIDTH; i++)
{
*(state + i) = add_mod_p(*(state + i), *(ark + i));
}
}
static void exp_acc(const uint64_t m, const uint64_t *base, const uint64_t *tail, uint64_t *const res)
{
for (uint64_t i = 0; i < m; i++)
{
for (uint64_t j = 0; j < STATE_WIDTH; j++)
{
if (i == 0)
{
*(res + j) = mult_mod_p(*(base + j), *(base + j));
}
else
{
*(res + j) = mult_mod_p(*(res + j), *(res + j));
}
}
}
for (uint64_t i = 0; i < STATE_WIDTH; i++)
{
*(res + i) = mult_mod_p(*(res + i), *(tail + i));
}
}
static void apply_inv_sbox(uint64_t *const state)
{
uint64_t t1[STATE_WIDTH];
for (uint64_t i = 0; i < STATE_WIDTH; i++)
{
t1[i] = 0;
}
for (uint64_t i = 0; i < STATE_WIDTH; i++)
{
t1[i] = mult_mod_p(*(state + i), *(state + i));
}
uint64_t t2[STATE_WIDTH];
for (uint64_t i = 0; i < STATE_WIDTH; i++)
{
t2[i] = 0;
}
for (uint64_t i = 0; i < STATE_WIDTH; i++)
{
t2[i] = mult_mod_p(t1[i], t1[i]);
}
uint64_t t3[STATE_WIDTH];
for (uint64_t i = 0; i < STATE_WIDTH; i++)
{
t3[i] = 0;
}
exp_acc(3, t2, t2, t3);
uint64_t t4[STATE_WIDTH];
for (uint64_t i = 0; i < STATE_WIDTH; i++)
{
t4[i] = 0;
}
exp_acc(6, t3, t3, t4);
uint64_t tmp[STATE_WIDTH];
for (uint64_t i = 0; i < STATE_WIDTH; i++)
{
tmp[i] = 0;
}
exp_acc(12, t4, t4, tmp);
uint64_t t5[STATE_WIDTH];
for (uint64_t i = 0; i < STATE_WIDTH; i++)
{
t5[i] = 0;
}
exp_acc(6, tmp, t3, t5);
uint64_t t6[STATE_WIDTH];
for (uint64_t i = 0; i < STATE_WIDTH; i++)
{
t6[i] = 0;
}
exp_acc(31, t5, t5, t6);
for (uint64_t i = 0; i < STATE_WIDTH; i++)
{
uint64_t a = mult_mod_p(mult_mod_p(t6[i], t6[i]), t5[i]);
a = mult_mod_p(a, a);
a = mult_mod_p(a, a);
uint64_t b = mult_mod_p(mult_mod_p(t1[i], t2[i]), *(state + i));
*(state + i) = mult_mod_p(a, b);
}
}
static void apply_round(uint64_t *const state, const uint64_t round)
{
apply_mds(state);
apply_constants(state, ARK1[round]);
apply_sbox(state);
apply_mds(state);
apply_constants(state, ARK2[round]);
apply_inv_sbox(state);
}
static void apply_permutation(uint64_t *state)
{
for (uint64_t i = 0; i < NUM_ROUNDS; i++)
{
apply_round(state, i);
}
}
/* ================================================================================================
* RPO128 implementation. This is supposed to substitute SHAKE256 in the hash-to-point algorithm.
*/
#include "rpo.h"
void rpo128_init(rpo128_context *rc)
{
rc->dptr = 32;
memset(rc->st.A, 0, sizeof rc->st.A);
}
void rpo128_absorb(rpo128_context *rc, const uint8_t *in, size_t len)
{
size_t dptr;
dptr = (size_t)rc->dptr;
while (len > 0)
{
size_t clen, u;
/* The rate is 136 * 8 = 1088 bits in the case of SHAKE256.
 * For RPO128, it is 64 * 8 = 512 bits.
 * The capacity sits at the end of the state for SHAKE256, while for RPO128
 * it is at the beginning.
 */
clen = 96 - dptr;
if (clen > len)
{
clen = len;
}
for (u = 0; u < clen; u++)
{
rc->st.dbuf[dptr + u] = in[u];
}
dptr += clen;
in += clen;
len -= clen;
if (dptr == 96)
{
apply_permutation(rc->st.A);
dptr = 32;
}
}
rc->dptr = dptr;
}
void rpo128_finalize(rpo128_context *rc)
{
// Set dptr to the end of the buffer, so that first call to extract will call the permutation.
rc->dptr = 96;
}
void rpo128_squeeze(rpo128_context *rc, uint8_t *out, size_t len)
{
size_t dptr;
dptr = (size_t)rc->dptr;
while (len > 0)
{
size_t clen;
if (dptr == 96)
{
apply_permutation(rc->st.A);
dptr = 32;
}
clen = 96 - dptr;
if (clen > len)
{
clen = len;
}
len -= clen;
memcpy(out, rc->st.dbuf + dptr, clen);
dptr += clen;
out += clen;
}
rc->dptr = dptr;
}
void rpo128_release(rpo128_context *rc)
{
memset(rc->st.A, 0, sizeof rc->st.A);
rc->dptr = 32;
}
/* ================================================================================================
* Hash-to-Point algorithm implementation based on RPO128
*/
void PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(rpo128_context *rc, uint16_t *x, unsigned logn)
{
/*
* This implementation avoids the rejection sampling step needed in the
* per-the-spec implementation. It uses a remark in https://falcon-sign.info/falcon.pdf
* page 31, which argues that the current variant is secure for the parameters set by NIST.
* Avoiding the rejection-sampling step leads to an implementation that is constant-time.
* TODO: Check that the current implementation is indeed constant-time.
*/
size_t n;
n = (size_t)1 << logn;
while (n > 0)
{
uint8_t buf[8];
uint64_t w;
rpo128_squeeze(rc, (void *)buf, sizeof buf);
w = ((uint64_t)(buf[7]) << 56) |
((uint64_t)(buf[6]) << 48) |
((uint64_t)(buf[5]) << 40) |
((uint64_t)(buf[4]) << 32) |
((uint64_t)(buf[3]) << 24) |
((uint64_t)(buf[2]) << 16) |
((uint64_t)(buf[1]) << 8) |
((uint64_t)(buf[0]));
w %= M;
*x++ = (uint16_t)w;
n--;
}
}


@@ -1,83 +0,0 @@
#include <stdint.h>
#include <string.h>
/* ================================================================================================
* RPO hashing algorithm related structs and methods.
*/
/*
* RPO128 context.
*
* This structure is used by the hashing API. It is composed of an internal state that can be
* viewed as either:
* 1. 12 field elements in the Miden VM.
* 2. 96 bytes.
*
* The first view is used for the internal state in the context of the RPO hashing algorithm. The
* second view is used for the buffer used to absorb the data to be hashed.
*
* The pointer to the buffer is updated as the data is absorbed.
*
* 'rpo128_context' must be initialized with rpo128_init() before first use.
*/
typedef struct
{
union
{
uint64_t A[12];
uint8_t dbuf[96];
} st;
uint64_t dptr;
} rpo128_context;
/*
* Initializes an RPO state
*/
void rpo128_init(rpo128_context *rc);
/*
* Absorbs an array of bytes of length 'len' into the state.
*/
void rpo128_absorb(rpo128_context *rc, const uint8_t *in, size_t len);
/*
* Squeezes an array of bytes of length 'len' from the state.
*/
void rpo128_squeeze(rpo128_context *rc, uint8_t *out, size_t len);
/*
* Finalizes the state in preparation for squeezing.
*
* This function should be called after all the data has been absorbed.
*
* Note that the current implementation does not perform any sort of padding for domain separation
* purposes. The reason being that, for our purposes, we always perform the following sequence:
* 1. Absorb a Nonce (which is always 40 bytes packed as 8 field elements).
* 2. Absorb the message (which is always 4 field elements).
* 3. Call finalize.
* 4. Squeeze the output.
* 5. Call release.
*/
void rpo128_finalize(rpo128_context *rc);
/*
* Releases the state.
*
* This function should be called after the squeeze operation is finished.
*/
void rpo128_release(rpo128_context *rc);
/* ================================================================================================
* Hash-to-Point algorithm for signature generation and signature verification.
*/
/*
* Hash-to-Point algorithm.
*
* This function generates a point in Z_q[x]/(phi) from a given message.
*
* It takes a finalized rpo128_context as input and generates the coefficients of the polynomial
* representing the point. The coefficients are stored in the array 'x'. The number of
* coefficients is 2^logn; in our case logn must be 9, giving 512 coefficients.
*/
void PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(rpo128_context *rc, uint16_t *x, unsigned logn);


@@ -1,189 +0,0 @@
use libc::c_int;
// C IMPLEMENTATION INTERFACE
// ================================================================================================
#[link(name = "rpo_falcon512", kind = "static")]
extern "C" {
/// Generate a new key pair. Public key goes into pk[], private key in sk[].
/// Key sizes are exact (in bytes):
/// - public (pk): 897
/// - private (sk): 1281
///
/// Return value: 0 on success, -1 on error.
pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(pk: *mut u8, sk: *mut u8) -> c_int;
/// Generate a new key pair from seed. Public key goes into pk[], private key in sk[].
/// Key sizes are exact (in bytes):
/// - public (pk): 897
/// - private (sk): 1281
///
/// Return value: 0 on success, -1 on error.
pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(
pk: *mut u8,
sk: *mut u8,
seed: *const u8,
) -> c_int;
/// Compute a signature on a provided message (m, mlen), with a given private key (sk).
/// Signature is written in sig[], with length written into *siglen. Signature length is
/// variable; maximum signature length (in bytes) is 666.
///
/// sig[], m[] and sk[] may overlap each other arbitrarily.
///
/// Return value: 0 on success, -1 on error.
pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
sig: *mut u8,
siglen: *mut usize,
m: *const u8,
mlen: usize,
sk: *const u8,
) -> c_int;
// TEST HELPERS
// --------------------------------------------------------------------------------------------
/// Verify a signature (sig, siglen) on a message (m, mlen) with a given public key (pk).
///
/// sig[], m[] and pk[] may overlap each other arbitrarily.
///
/// Return value: 0 on success, -1 on error.
#[cfg(test)]
pub fn PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
sig: *const u8,
siglen: usize,
m: *const u8,
mlen: usize,
pk: *const u8,
) -> c_int;
/// Hash-to-Point algorithm.
///
/// This function generates a point in Z_q[x]/(phi) from a given message.
///
/// It takes a finalized rpo128_context as input and generates the coefficients of the
/// polynomial representing the point. The coefficients are stored in the array 'x'. The
/// number of coefficients is 2^logn; in our case logn must be 9, giving 512 coefficients.
#[cfg(test)]
pub fn PQCLEAN_FALCON512_CLEAN_hash_to_point_rpo(
rc: *mut Rpo128Context,
x: *mut u16,
logn: usize,
);
#[cfg(test)]
pub fn rpo128_init(sc: *mut Rpo128Context);
#[cfg(test)]
pub fn rpo128_absorb(
sc: *mut Rpo128Context,
data: *const ::std::os::raw::c_void,
len: libc::size_t,
);
#[cfg(test)]
pub fn rpo128_finalize(sc: *mut Rpo128Context);
}
#[repr(C)]
#[cfg(test)]
pub struct Rpo128Context {
pub content: [u64; 13usize],
}
// TESTS
// ================================================================================================
#[cfg(all(test, feature = "std"))]
mod tests {
use super::*;
use crate::dsa::rpo_falcon512::{NONCE_LEN, PK_LEN, SIG_LEN, SK_LEN};
use rand_utils::{rand_array, rand_value, rand_vector};
#[test]
fn falcon_ffi() {
unsafe {
// --- generate a key pair from a seed ----------------------------
let mut pk = [0u8; PK_LEN];
let mut sk = [0u8; SK_LEN];
let seed: [u8; NONCE_LEN] = rand_array();
assert_eq!(
0,
PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(
pk.as_mut_ptr(),
sk.as_mut_ptr(),
seed.as_ptr()
)
);
// --- sign a message and make sure it verifies -------------------
let mlen: usize = rand_value::<u16>() as usize;
let msg: Vec<u8> = rand_vector(mlen);
let mut detached_sig = [0u8; NONCE_LEN + SIG_LEN];
let mut siglen = 0;
assert_eq!(
0,
PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
detached_sig.as_mut_ptr(),
&mut siglen as *mut usize,
msg.as_ptr(),
msg.len(),
sk.as_ptr()
)
);
assert_eq!(
0,
PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
detached_sig.as_ptr(),
siglen,
msg.as_ptr(),
msg.len(),
pk.as_ptr()
)
);
// --- check verification of different signature ------------------
assert_eq!(
-1,
PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
detached_sig.as_ptr(),
siglen,
msg.as_ptr(),
msg.len() - 1,
pk.as_ptr()
)
);
// --- check verification against a different pub key -------------
let mut pk_alt = [0u8; PK_LEN];
let mut sk_alt = [0u8; SK_LEN];
assert_eq!(
0,
PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(
pk_alt.as_mut_ptr(),
sk_alt.as_mut_ptr()
)
);
assert_eq!(
-1,
PQCLEAN_FALCON512_CLEAN_crypto_sign_verify_rpo(
detached_sig.as_ptr(),
siglen,
msg.as_ptr(),
msg.len(),
pk_alt.as_ptr()
)
);
}
}
}


@@ -0,0 +1,70 @@
use alloc::vec::Vec;
use num::Zero;
use super::{math::FalconFelt, Nonce, Polynomial, Rpo256, Word, MODULUS, N, ZERO};
// HASH-TO-POINT FUNCTIONS
// ================================================================================================
/// Returns a polynomial in Z_p[x]/(phi) representing the hash of the provided message and
/// nonce using RPO256.
pub fn hash_to_point_rpo256(message: Word, nonce: &Nonce) -> Polynomial<FalconFelt> {
let mut state = [ZERO; Rpo256::STATE_WIDTH];
// absorb the nonce into the state
let nonce_elements = nonce.to_elements();
for (&n, s) in nonce_elements.iter().zip(state[Rpo256::RATE_RANGE].iter_mut()) {
*s = n;
}
Rpo256::apply_permutation(&mut state);
// absorb message into the state
for (&m, s) in message.iter().zip(state[Rpo256::RATE_RANGE].iter_mut()) {
*s = m;
}
// squeeze the coefficients of the polynomial
let mut i = 0;
let mut res = [FalconFelt::zero(); N];
for _ in 0..64 {
Rpo256::apply_permutation(&mut state);
for a in &state[Rpo256::RATE_RANGE] {
res[i] = FalconFelt::new((a.as_int() % MODULUS as u64) as i16);
i += 1;
}
}
Polynomial::new(res.to_vec())
}
/// Returns a polynomial in Z_p[x]/(phi) representing the hash of the provided message and
/// nonce using SHAKE256. This is the hash-to-point algorithm used in the reference implementation.
#[allow(dead_code)]
pub fn hash_to_point_shake256(message: &[u8], nonce: &Nonce) -> Polynomial<FalconFelt> {
use sha3::{
digest::{ExtendableOutput, Update, XofReader},
Shake256,
};
let mut data = vec![];
data.extend_from_slice(nonce.as_bytes());
data.extend_from_slice(message);
const K: u32 = (1u32 << 16) / MODULUS as u32;
let mut hasher = Shake256::default();
hasher.update(&data);
let mut reader = hasher.finalize_xof();
let mut coefficients: Vec<FalconFelt> = Vec::with_capacity(N);
while coefficients.len() != N {
let mut randomness = [0u8; 2];
reader.read(&mut randomness);
let t = ((randomness[0] as u32) << 8) | (randomness[1] as u32);
if t < K * MODULUS as u32 {
coefficients.push(FalconFelt::new((t % MODULUS as u32) as i16));
}
}
Polynomial { coefficients }
}


@@ -1,227 +0,0 @@
use super::{
ByteReader, ByteWriter, Deserializable, DeserializationError, FalconError, Polynomial,
PublicKeyBytes, Rpo256, SecretKeyBytes, Serializable, Signature, Word,
};
#[cfg(feature = "std")]
use super::{ffi, NonceBytes, StarkField, NONCE_LEN, PK_LEN, SIG_LEN, SK_LEN};
// PUBLIC KEY
// ================================================================================================
/// A public key for verifying signatures.
///
/// The public key is a [Word] (i.e., 4 field elements) that is the hash of the coefficients of
/// the polynomial representing the raw bytes of the expanded public key.
///
/// For Falcon-512, the first byte of the expanded public key is always equal to log2(512) i.e., 9.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PublicKey(Word);
impl PublicKey {
/// Returns a new [PublicKey] which is a commitment to the provided expanded public key.
///
/// # Errors
/// Returns an error if the decoding of the public key fails.
pub fn new(pk: PublicKeyBytes) -> Result<Self, FalconError> {
let h = Polynomial::from_pub_key(&pk)?;
let pk_felts = h.to_elements();
let pk_digest = Rpo256::hash_elements(&pk_felts).into();
Ok(Self(pk_digest))
}
/// Verifies the provided signature against provided message and this public key.
pub fn verify(&self, message: Word, signature: &Signature) -> bool {
signature.verify(message, self.0)
}
}
impl From<PublicKey> for Word {
fn from(key: PublicKey) -> Self {
key.0
}
}
// KEY PAIR
// ================================================================================================
/// A key pair (public and secret keys) for signing messages.
///
/// The secret key is a byte array of length [SK_LEN].
/// The public key is a byte array of length [PK_LEN].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct KeyPair {
public_key: PublicKeyBytes,
secret_key: SecretKeyBytes,
}
#[allow(clippy::new_without_default)]
impl KeyPair {
// CONSTRUCTORS
// --------------------------------------------------------------------------------------------
/// Generates a (public_key, secret_key) key pair from OS-provided randomness.
///
/// # Errors
/// Returns an error if key generation fails.
#[cfg(feature = "std")]
pub fn new() -> Result<Self, FalconError> {
let mut public_key = [0u8; PK_LEN];
let mut secret_key = [0u8; SK_LEN];
let res = unsafe {
ffi::PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo(
public_key.as_mut_ptr(),
secret_key.as_mut_ptr(),
)
};
if res == 0 {
Ok(Self { public_key, secret_key })
} else {
Err(FalconError::KeyGenerationFailed)
}
}
/// Generates a (public_key, secret_key) key pair from the provided seed.
///
/// # Errors
/// Returns an error if key generation fails.
#[cfg(feature = "std")]
pub fn from_seed(seed: &NonceBytes) -> Result<Self, FalconError> {
let mut public_key = [0u8; PK_LEN];
let mut secret_key = [0u8; SK_LEN];
let res = unsafe {
ffi::PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(
public_key.as_mut_ptr(),
secret_key.as_mut_ptr(),
seed.as_ptr(),
)
};
if res == 0 {
Ok(Self { public_key, secret_key })
} else {
Err(FalconError::KeyGenerationFailed)
}
}
// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------
/// Returns the public key corresponding to this key pair.
pub fn public_key(&self) -> PublicKey {
// TODO: memoize public key commitment as computing it requires quite a bit of hashing.
// expect() is fine here because we assume that the key pair was constructed correctly.
PublicKey::new(self.public_key).expect("invalid key pair")
}
/// Returns the expanded public key corresponding to this key pair.
pub fn expanded_public_key(&self) -> PublicKeyBytes {
self.public_key
}
// SIGNATURE GENERATION
// --------------------------------------------------------------------------------------------
/// Signs a message with the secret key of this key pair.
///
/// # Errors
/// Returns an error if signature generation fails.
#[cfg(feature = "std")]
pub fn sign(&self, message: Word) -> Result<Signature, FalconError> {
let msg = message.iter().flat_map(|e| e.as_int().to_le_bytes()).collect::<Vec<_>>();
let msg_len = msg.len();
let mut sig = [0_u8; SIG_LEN + NONCE_LEN];
let mut sig_len: usize = 0;
let res = unsafe {
ffi::PQCLEAN_FALCON512_CLEAN_crypto_sign_signature_rpo(
sig.as_mut_ptr(),
&mut sig_len as *mut usize,
msg.as_ptr(),
msg_len,
self.secret_key.as_ptr(),
)
};
if res == 0 {
Ok(Signature { sig, pk: self.public_key })
} else {
Err(FalconError::SigGenerationFailed)
}
}
}
// SERIALIZATION / DESERIALIZATION
// ================================================================================================
impl Serializable for KeyPair {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write_bytes(&self.public_key);
target.write_bytes(&self.secret_key);
}
}
impl Deserializable for KeyPair {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let public_key: PublicKeyBytes = source.read_array()?;
let secret_key: SecretKeyBytes = source.read_array()?;
Ok(Self { public_key, secret_key })
}
}
// TESTS
// ================================================================================================
#[cfg(all(test, feature = "std"))]
mod tests {
use super::{super::Felt, KeyPair, NonceBytes, Word};
use rand_utils::{rand_array, rand_vector};
#[test]
fn test_falcon_verification() {
// generate random keys
let keys = KeyPair::new().unwrap();
let pk = keys.public_key();
// sign a random message
let message: Word = rand_vector::<Felt>(4).try_into().expect("Should not fail.");
let signature = keys.sign(message);
// make sure the signature verifies correctly
assert!(pk.verify(message, signature.as_ref().unwrap()));
// a signature should not verify against a wrong message
let message2: Word = rand_vector::<Felt>(4).try_into().expect("Should not fail.");
assert!(!pk.verify(message2, signature.as_ref().unwrap()));
// a signature should not verify against a wrong public key
let keys2 = KeyPair::new().unwrap();
assert!(!keys2.public_key().verify(message, signature.as_ref().unwrap()))
}
#[test]
fn test_falcon_verification_from_seed() {
// generate keys from a random seed
let seed: NonceBytes = rand_array();
let keys = KeyPair::from_seed(&seed).unwrap();
let pk = keys.public_key();
// sign a random message
let message: Word = rand_vector::<Felt>(4).try_into().expect("Should not fail.");
let signature = keys.sign(message);
// make sure the signature verifies correctly
assert!(pk.verify(message, signature.as_ref().unwrap()));
// a signature should not verify against a wrong message
let message2: Word = rand_vector::<Felt>(4).try_into().expect("Should not fail.");
assert!(!pk.verify(message2, signature.as_ref().unwrap()));
// a signature should not verify against a wrong public key
let keys2 = KeyPair::new().unwrap();
assert!(!keys2.public_key().verify(message, signature.as_ref().unwrap()))
}
}


@@ -0,0 +1,55 @@
use super::{
math::{FalconFelt, Polynomial},
ByteReader, ByteWriter, Deserializable, DeserializationError, Felt, Serializable, Signature,
Word,
};
mod public_key;
pub use public_key::{PubKeyPoly, PublicKey};
mod secret_key;
pub use secret_key::SecretKey;
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
use rand::SeedableRng;
use rand_chacha::ChaCha20Rng;
use winter_math::FieldElement;
use winter_utils::{Deserializable, Serializable};
use crate::{dsa::rpo_falcon512::SecretKey, Word, ONE};
#[test]
fn test_falcon_verification() {
let seed = [0_u8; 32];
let mut rng = ChaCha20Rng::from_seed(seed);
// generate random keys
let sk = SecretKey::with_rng(&mut rng);
let pk = sk.public_key();
// test secret key serialization/deserialization
let mut buffer = vec![];
sk.write_into(&mut buffer);
let sk_deserialized = SecretKey::read_from_bytes(&buffer).unwrap();
assert_eq!(sk.short_lattice_basis(), sk_deserialized.short_lattice_basis());
// sign a random message
let message: Word = [ONE; 4];
let signature = sk.sign_with_rng(message, &mut rng);
// make sure the signature verifies correctly
assert!(pk.verify(message, &signature));
// a signature should not verify against a wrong message
let message2: Word = [ONE.double(); 4];
assert!(!pk.verify(message2, &signature));
// a signature should not verify against a wrong public key
let sk2 = SecretKey::with_rng(&mut rng);
assert!(!sk2.public_key().verify(message, &signature))
}
}

@@ -0,0 +1,139 @@
use alloc::string::ToString;
use core::ops::Deref;
use num::Zero;
use super::{
super::{Rpo256, LOG_N, N, PK_LEN},
ByteReader, ByteWriter, Deserializable, DeserializationError, FalconFelt, Felt, Polynomial,
Serializable, Signature, Word,
};
use crate::dsa::rpo_falcon512::FALCON_ENCODING_BITS;
// PUBLIC KEY
// ================================================================================================
/// A public key for verifying signatures.
///
/// The public key is a [Word] (i.e., 4 field elements) which is the Rpo256 hash of the
/// coefficients of the polynomial representing the expanded public key.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PublicKey(Word);
impl PublicKey {
/// Returns a new [PublicKey] which is a commitment to the provided expanded public key.
pub fn new(pub_key: Word) -> Self {
Self(pub_key)
}
/// Verifies the provided signature against the provided message and this public key.
pub fn verify(&self, message: Word, signature: &Signature) -> bool {
signature.verify(message, self.0)
}
}
impl From<PubKeyPoly> for PublicKey {
fn from(pk_poly: PubKeyPoly) -> Self {
let pk_felts: Polynomial<Felt> = pk_poly.0.into();
let pk_digest = Rpo256::hash_elements(&pk_felts.coefficients).into();
Self(pk_digest)
}
}
impl From<PublicKey> for Word {
fn from(key: PublicKey) -> Self {
key.0
}
}
// PUBLIC KEY POLYNOMIAL
// ================================================================================================
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PubKeyPoly(pub Polynomial<FalconFelt>);
impl Deref for PubKeyPoly {
type Target = Polynomial<FalconFelt>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl From<Polynomial<FalconFelt>> for PubKeyPoly {
fn from(pk_poly: Polynomial<FalconFelt>) -> Self {
Self(pk_poly)
}
}
impl Serializable for &PubKeyPoly {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
let mut buf = [0_u8; PK_LEN];
buf[0] = LOG_N;
let mut acc = 0_u32;
let mut acc_len: u32 = 0;
let mut input_pos = 1;
for c in self.0.coefficients.iter() {
let c = c.value();
acc = (acc << FALCON_ENCODING_BITS) | c as u32;
acc_len += FALCON_ENCODING_BITS;
while acc_len >= 8 {
acc_len -= 8;
buf[input_pos] = (acc >> acc_len) as u8;
input_pos += 1;
}
}
if acc_len > 0 {
buf[input_pos] = (acc << (8 - acc_len)) as u8;
}
target.write(buf);
}
}
impl Deserializable for PubKeyPoly {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let buf = source.read_array::<PK_LEN>()?;
if buf[0] != LOG_N {
return Err(DeserializationError::InvalidValue(format!(
"Failed to decode public key: expected the first byte to be {LOG_N} but was {}",
buf[0]
)));
}
let mut acc = 0_u32;
let mut acc_len = 0;
let mut output = [FalconFelt::zero(); N];
let mut output_idx = 0;
for &byte in buf.iter().skip(1) {
acc = (acc << 8) | (byte as u32);
acc_len += 8;
if acc_len >= FALCON_ENCODING_BITS {
acc_len -= FALCON_ENCODING_BITS;
let w = (acc >> acc_len) & 0x3fff;
let element = w.try_into().map_err(|err| {
DeserializationError::InvalidValue(format!(
"Failed to decode public key: {err}"
))
})?;
output[output_idx] = element;
output_idx += 1;
}
}
if (acc & ((1u32 << acc_len) - 1)) == 0 {
Ok(Polynomial::new(output.to_vec()).into())
} else {
Err(DeserializationError::InvalidValue(
"Failed to decode public key: input not fully consumed".to_string(),
))
}
}
}
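The (de)serialization impls above stream 14-bit coefficients through a bit accumulator: serialization pushes `FALCON_ENCODING_BITS` bits per coefficient and emits full bytes from the top; deserialization reverses the process and rejects non-zero trailing bits. A minimal standalone sketch of that packing, with hypothetical `pack`/`unpack` helpers that are illustration only and not part of the crate:

```rust
// Hypothetical sketch of the 14-bit accumulator packing used by the
// PubKeyPoly (de)serialization; `pack`/`unpack` are illustrative names.
const BITS: u32 = 14; // mirrors FALCON_ENCODING_BITS

fn pack(values: &[u16]) -> Vec<u8> {
    let (mut acc, mut acc_len) = (0u32, 0u32);
    let mut out = Vec::new();
    for &v in values {
        acc = (acc << BITS) | v as u32; // push 14 fresh bits into the accumulator
        acc_len += BITS;
        while acc_len >= 8 {
            acc_len -= 8;
            out.push((acc >> acc_len) as u8); // emit the top full byte
        }
    }
    if acc_len > 0 {
        out.push((acc << (8 - acc_len)) as u8); // left-align any trailing bits
    }
    out
}

fn unpack(bytes: &[u8], count: usize) -> Vec<u16> {
    let (mut acc, mut acc_len) = (0u32, 0u32);
    let mut out = Vec::new();
    for &b in bytes {
        acc = (acc << 8) | b as u32;
        acc_len += 8;
        while acc_len >= BITS && out.len() < count {
            acc_len -= BITS;
            out.push(((acc >> acc_len) & 0x3fff) as u16); // mask to 14 bits
        }
    }
    out
}

fn main() {
    let coeffs: Vec<u16> = vec![12288, 0, 1, 5000];
    let packed = pack(&coeffs);
    assert_eq!(packed.len(), (4 * 14 + 7) / 8); // 7 bytes for 4 coefficients
    assert_eq!(unpack(&packed, 4), coeffs);
}
```

For the real parameters (N = 512 coefficients at 14 bits) the payload is 512 · 14 = 7168 bits, an exact multiple of 8, so the trailing-bits branch never fires in practice.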

@@ -0,0 +1,401 @@
use alloc::{string::ToString, vec::Vec};
use num::Complex;
#[cfg(not(feature = "std"))]
use num::Float;
use num_complex::Complex64;
use rand::Rng;
use super::{
super::{
math::{ffldl, ffsampling, gram, normalize_tree, FalconFelt, FastFft, LdlTree, Polynomial},
signature::SignaturePoly,
ByteReader, ByteWriter, Deserializable, DeserializationError, Nonce, Serializable,
ShortLatticeBasis, Signature, Word, MODULUS, N, SIGMA, SIG_L2_BOUND,
},
PubKeyPoly, PublicKey,
};
use crate::dsa::rpo_falcon512::{
hash_to_point::hash_to_point_rpo256, math::ntru_gen, SIG_NONCE_LEN, SK_LEN,
};
// CONSTANTS
// ================================================================================================
const WIDTH_BIG_POLY_COEFFICIENT: usize = 8;
const WIDTH_SMALL_POLY_COEFFICIENT: usize = 6;
// SECRET KEY
// ================================================================================================
/// Represents the secret key for Falcon DSA.
///
/// The secret key is a quadruple [[g, -f], [G, -F]] of polynomials with integer coefficients. Each
/// polynomial is of degree at most N = 512 and computations with these polynomials are done modulo
/// the monic irreducible polynomial ϕ = x^N + 1. The secret key is a basis for a lattice and has
/// the property of being short with respect to a certain norm and an upper bound appropriate for
/// a given security parameter. The public key, on the other hand, is another basis for the same
/// lattice and can be described by a single polynomial h with integer coefficients modulo ϕ.
/// The two keys are related by the following relation:
///
/// 1. h = g / f [mod ϕ][mod p]
/// 2. f.G - g.F = p [mod ϕ]
///
/// where p = 12289 is the Falcon prime. Equation 2 is called the NTRU equation.
/// The secret key is generated by first sampling a random pair (f, g) of polynomials using
/// an appropriate distribution that yields short but not too short polynomials with integer
/// coefficients modulo ϕ. The NTRU equation is then used to find a matching pair (F, G).
/// The public key is then derived from the secret key using equation 1.
///
/// To allow for fast signature generation, the secret key is pre-processed into a more suitable
/// form, called the LDL tree, and this allows for fast sampling of short vectors in the lattice
/// using Fast Fourier sampling during signature generation (ffSampling algorithm 11 in [1]).
///
/// [1]: https://falcon-sign.info/falcon.pdf
#[derive(Debug, Clone)]
pub struct SecretKey {
secret_key: ShortLatticeBasis,
tree: LdlTree,
}
#[allow(clippy::new_without_default)]
impl SecretKey {
// CONSTRUCTORS
// --------------------------------------------------------------------------------------------
/// Generates a secret key from OS-provided randomness.
#[cfg(feature = "std")]
pub fn new() -> Self {
use rand::{rngs::StdRng, SeedableRng};
let mut rng = StdRng::from_entropy();
Self::with_rng(&mut rng)
}
/// Generates a secret key using the provided random number generator `rng`.
pub fn with_rng<R: Rng>(rng: &mut R) -> Self {
let basis = ntru_gen(N, rng);
Self::from_short_lattice_basis(basis)
}
/// Given a short basis [[g, -f], [G, -F]], computes the normalized LDL tree i.e., Falcon tree.
fn from_short_lattice_basis(basis: ShortLatticeBasis) -> SecretKey {
// FFT each polynomial of the short basis.
let basis_fft = to_complex_fft(&basis);
// compute the Gram matrix.
let gram_fft = gram(basis_fft);
// construct the LDL tree of the Gram matrix.
let mut tree = ffldl(gram_fft);
// normalize the leaves of the LDL tree.
normalize_tree(&mut tree, SIGMA);
Self { secret_key: basis, tree }
}
// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------
/// Returns the polynomials of the short lattice basis of this secret key.
pub fn short_lattice_basis(&self) -> &ShortLatticeBasis {
&self.secret_key
}
/// Returns the public key corresponding to this secret key.
pub fn public_key(&self) -> PublicKey {
self.compute_pub_key_poly().into()
}
/// Returns the LDL tree associated to this secret key.
pub fn tree(&self) -> &LdlTree {
&self.tree
}
// SIGNATURE GENERATION
// --------------------------------------------------------------------------------------------
/// Signs a message with this secret key.
#[cfg(feature = "std")]
pub fn sign(&self, message: Word) -> Signature {
use rand::{rngs::StdRng, SeedableRng};
let mut rng = StdRng::from_entropy();
self.sign_with_rng(message, &mut rng)
}
/// Signs a message with the secret key relying on the provided randomness generator.
pub fn sign_with_rng<R: Rng>(&self, message: Word, rng: &mut R) -> Signature {
let mut nonce_bytes = [0u8; SIG_NONCE_LEN];
rng.fill_bytes(&mut nonce_bytes);
let nonce = Nonce::new(nonce_bytes);
let h = self.compute_pub_key_poly();
let c = hash_to_point_rpo256(message, &nonce);
let s2 = self.sign_helper(c, rng);
Signature::new(nonce, h, s2)
}
// HELPER METHODS
// --------------------------------------------------------------------------------------------
/// Derives the public key corresponding to this secret key using h = g / f [mod ϕ][mod p].
pub fn compute_pub_key_poly(&self) -> PubKeyPoly {
let g: Polynomial<FalconFelt> = self.secret_key[0].clone().into();
let g_fft = g.fft();
let minus_f: Polynomial<FalconFelt> = self.secret_key[1].clone().into();
let f = -minus_f;
let f_fft = f.fft();
let h_fft = g_fft.hadamard_div(&f_fft);
h_fft.ifft().into()
}
/// Signs a message polynomial with the secret key.
///
/// Takes a randomness generator implementing `Rng` and a message polynomial `c`, the
/// hash-to-point of the message to be signed. It outputs a signature polynomial `s2`.
fn sign_helper<R: Rng>(&self, c: Polynomial<FalconFelt>, rng: &mut R) -> SignaturePoly {
let one_over_q = 1.0 / (MODULUS as f64);
let c_over_q_fft = c.map(|cc| Complex::new(one_over_q * cc.value() as f64, 0.0)).fft();
// B = [[FFT(g), -FFT(f)], [FFT(G), -FFT(F)]]
let [g_fft, minus_f_fft, big_g_fft, minus_big_f_fft] = to_complex_fft(&self.secret_key);
let t0 = c_over_q_fft.hadamard_mul(&minus_big_f_fft);
let t1 = -c_over_q_fft.hadamard_mul(&minus_f_fft);
loop {
let bold_s = loop {
let z = ffsampling(&(t0.clone(), t1.clone()), &self.tree, rng);
let t0_min_z0 = t0.clone() - z.0;
let t1_min_z1 = t1.clone() - z.1;
// s = (t-z) * B
let s0 = t0_min_z0.hadamard_mul(&g_fft) + t1_min_z1.hadamard_mul(&big_g_fft);
let s1 =
t0_min_z0.hadamard_mul(&minus_f_fft) + t1_min_z1.hadamard_mul(&minus_big_f_fft);
// compute the norm of (s0||s1) and note that they are in FFT representation
let length_squared: f64 =
(s0.coefficients.iter().map(|a| (a * a.conj()).re).sum::<f64>()
+ s1.coefficients.iter().map(|a| (a * a.conj()).re).sum::<f64>())
/ (N as f64);
if length_squared > (SIG_L2_BOUND as f64) {
continue;
}
break [-s0, s1];
};
let s2 = bold_s[1].ifft();
let s2_coef: [i16; N] = s2
.coefficients
.iter()
.map(|a| a.re.round() as i16)
.collect::<Vec<i16>>()
.try_into()
.expect("The number of coefficients should be equal to N");
if let Ok(s2) = SignaturePoly::try_from(&s2_coef) {
return s2;
}
}
}
}
// SERIALIZATION / DESERIALIZATION
// ================================================================================================
impl Serializable for SecretKey {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
let basis = &self.secret_key;
// header
let n = basis[0].coefficients.len();
let l = n.checked_ilog2().unwrap() as u8;
let header: u8 = (5 << 4) | l;
let neg_f = &basis[1];
let g = &basis[0];
let neg_big_f = &basis[3];
let mut buffer = Vec::with_capacity(1281);
buffer.push(header);
let f_i8: Vec<i8> = neg_f
.coefficients
.iter()
.map(|&a| FalconFelt::new(-a).balanced_value() as i8)
.collect();
let f_i8_encoded = encode_i8(&f_i8, WIDTH_SMALL_POLY_COEFFICIENT).unwrap();
buffer.extend_from_slice(&f_i8_encoded);
let g_i8: Vec<i8> = g
.coefficients
.iter()
.map(|&a| FalconFelt::new(a).balanced_value() as i8)
.collect();
let g_i8_encoded = encode_i8(&g_i8, WIDTH_SMALL_POLY_COEFFICIENT).unwrap();
buffer.extend_from_slice(&g_i8_encoded);
let big_f_i8: Vec<i8> = neg_big_f
.coefficients
.iter()
.map(|&a| FalconFelt::new(-a).balanced_value() as i8)
.collect();
let big_f_i8_encoded = encode_i8(&big_f_i8, WIDTH_BIG_POLY_COEFFICIENT).unwrap();
buffer.extend_from_slice(&big_f_i8_encoded);
target.write_bytes(&buffer);
}
}
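The header byte written above packs the fixed nibble 5 with log2(N); for N = 512 that yields 0x59. A small sketch with hypothetical helper names, mirroring the encode here and the checks performed by the `Deserializable` impl:

```rust
// Sketch of the one-byte secret-key header: high nibble fixed to 5,
// low nibble log2(N). Helper names are hypothetical.
fn header_for(n: usize) -> u8 {
    (5 << 4) | n.ilog2() as u8
}

fn parse_header(header: u8) -> Option<usize> {
    if header >> 4 != 5 {
        return None; // fixed bits malformed, mirrors the check in read_from
    }
    Some(1usize << (header & 15))
}

fn main() {
    assert_eq!(header_for(512), 0x59); // 5 in the high nibble, 9 = log2(512)
    assert_eq!(parse_header(0x59), Some(512));
    assert_eq!(parse_header(0x19), None);
}
```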
impl Deserializable for SecretKey {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let byte_vector: [u8; SK_LEN] = source.read_array()?;
// check length
if byte_vector.len() < 2 {
return Err(DeserializationError::InvalidValue("Invalid encoding length: Failed to decode as length is different from the one expected".to_string()));
}
// read fields
let header = byte_vector[0];
// check fixed bits in header
if (header >> 4) != 5 {
return Err(DeserializationError::InvalidValue("Invalid header format".to_string()));
}
// check log n
let logn = (header & 15) as usize;
let n = 1 << logn;
// match against const variant generic parameter
if n != N {
return Err(DeserializationError::InvalidValue(
"Unsupported Falcon DSA variant".to_string(),
));
}
if byte_vector.len() != SK_LEN {
return Err(DeserializationError::InvalidValue("Invalid encoding length: Failed to decode as length is different from the one expected".to_string()));
}
let chunk_size_f = ((n * WIDTH_SMALL_POLY_COEFFICIENT) + 7) >> 3;
let chunk_size_g = ((n * WIDTH_SMALL_POLY_COEFFICIENT) + 7) >> 3;
let chunk_size_big_f = ((n * WIDTH_BIG_POLY_COEFFICIENT) + 7) >> 3;
let f = decode_i8(&byte_vector[1..chunk_size_f + 1], WIDTH_SMALL_POLY_COEFFICIENT).unwrap();
let g = decode_i8(
&byte_vector[chunk_size_f + 1..(chunk_size_f + chunk_size_g + 1)],
WIDTH_SMALL_POLY_COEFFICIENT,
)
.unwrap();
let big_f = decode_i8(
&byte_vector[(chunk_size_f + chunk_size_g + 1)
..(chunk_size_f + chunk_size_g + chunk_size_big_f + 1)],
WIDTH_BIG_POLY_COEFFICIENT,
)
.unwrap();
let f = Polynomial::new(f.iter().map(|&c| FalconFelt::new(c.into())).collect());
let g = Polynomial::new(g.iter().map(|&c| FalconFelt::new(c.into())).collect());
let big_f = Polynomial::new(big_f.iter().map(|&c| FalconFelt::new(c.into())).collect());
// big_g * f - g * big_f = p (mod X^n + 1)
let big_g = g.fft().hadamard_div(&f.fft()).hadamard_mul(&big_f.fft()).ifft();
let basis = [
g.map(|f| f.balanced_value()),
-f.map(|f| f.balanced_value()),
big_g.map(|f| f.balanced_value()),
-big_f.map(|f| f.balanced_value()),
];
Ok(Self::from_short_lattice_basis(basis))
}
}
// HELPER FUNCTIONS
// ================================================================================================
/// Computes the complex FFT of the secret key polynomials.
fn to_complex_fft(basis: &[Polynomial<i16>; 4]) -> [Polynomial<Complex<f64>>; 4] {
let [g, f, big_g, big_f] = basis.clone();
let g_fft = g.map(|cc| Complex64::new(*cc as f64, 0.0)).fft();
let minus_f_fft = f.map(|cc| -Complex64::new(*cc as f64, 0.0)).fft();
let big_g_fft = big_g.map(|cc| Complex64::new(*cc as f64, 0.0)).fft();
let minus_big_f_fft = big_f.map(|cc| -Complex64::new(*cc as f64, 0.0)).fft();
[g_fft, minus_f_fft, big_g_fft, minus_big_f_fft]
}
/// Encodes a sequence of signed integers such that each integer x satisfies |x| < 2^(bits-1)
/// for a given parameter bits. bits can take either the value 6 or 8.
pub fn encode_i8(x: &[i8], bits: usize) -> Option<Vec<u8>> {
let maxv = (1 << (bits - 1)) - 1_usize;
let maxv = maxv as i8;
let minv = -maxv;
for &c in x {
if c > maxv || c < minv {
return None;
}
}
let out_len = ((N * bits) + 7) >> 3;
let mut buf = vec![0_u8; out_len];
let mut acc = 0_u32;
let mut acc_len = 0;
let mask = ((1_u16 << bits) - 1) as u8;
let mut input_pos = 0;
for &c in x {
acc = (acc << bits) | (c as u8 & mask) as u32;
acc_len += bits;
while acc_len >= 8 {
acc_len -= 8;
buf[input_pos] = (acc >> acc_len) as u8;
input_pos += 1;
}
}
if acc_len > 0 {
buf[input_pos] = (acc << (8 - acc_len)) as u8;
}
Some(buf)
}
/// Decodes a sequence of bytes into a sequence of signed integers such that each integer x
/// satisfies |x| < 2^(bits-1) for a given parameter bits. bits can take either the value 6 or 8.
pub fn decode_i8(buf: &[u8], bits: usize) -> Option<Vec<i8>> {
let mut x = [0_i8; N];
let mut i = 0;
let mut j = 0;
let mut acc = 0_u32;
let mut acc_len = 0;
let mask = (1_u32 << bits) - 1;
let a = (1 << bits) as u8;
let b = ((1 << (bits - 1)) - 1) as u8;
while i < N {
acc = (acc << 8) | (buf[j] as u32);
j += 1;
acc_len += 8;
while acc_len >= bits && i < N {
acc_len -= bits;
let w = (acc >> acc_len) & mask;
let w = w as u8;
let z = if w > b { w as i8 - a as i8 } else { w as i8 };
x[i] = z;
i += 1;
}
}
if (acc & ((1u32 << acc_len) - 1)) == 0 {
Some(x.to_vec())
} else {
None
}
}
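The sign recovery in `decode_i8` is plain two's complement at width `bits`: a word w above 2^(bits-1) - 1 maps back to w - 2^bits. A standalone sketch with illustrative helper names (not the crate's API), exercising the 6-bit width used for f and g:

```rust
// Standalone sketch of the two's-complement recovery in decode_i8.
// `from_signed`/`to_signed` are illustrative names only.
fn from_signed(c: i8, bits: u32) -> u8 {
    let mask = ((1u16 << bits) - 1) as u8;
    (c as u8) & mask // keep the low `bits` bits, as encode_i8 does
}

fn to_signed(w: u8, bits: u32) -> i8 {
    let a = 1i16 << bits; // 2^bits
    let b = (1i16 << (bits - 1)) - 1; // largest positive value
    let w = w as i16;
    (if w > b { w - a } else { w }) as i8
}

fn main() {
    // 6-bit width used for the f and g polynomials
    for c in -31i8..=31 {
        assert_eq!(to_signed(from_signed(c, 6), 6), c);
    }
    // 0b100001 = 33 decodes to 33 - 64 = -31
    assert_eq!(to_signed(0b10_0001, 6), -31);
}
```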

@@ -0,0 +1,124 @@
use alloc::boxed::Box;
#[cfg(not(feature = "std"))]
use num::Float;
use num::{One, Zero};
use num_complex::{Complex, Complex64};
use rand::Rng;
use super::{fft::FastFft, polynomial::Polynomial, samplerz::sampler_z};
const SIGMIN: f64 = 1.2778336969128337;
/// Computes the Gram matrix. The argument must be a 2x2 matrix
/// whose elements are equal-length vectors of complex numbers,
/// representing polynomials in FFT domain.
pub fn gram(b: [Polynomial<Complex64>; 4]) -> [Polynomial<Complex64>; 4] {
const N: usize = 2;
let mut g: [Polynomial<Complex<f64>>; 4] =
[Polynomial::zero(), Polynomial::zero(), Polynomial::zero(), Polynomial::zero()];
for i in 0..N {
for j in 0..N {
for k in 0..N {
g[N * i + j] = g[N * i + j].clone()
+ b[N * i + k].hadamard_mul(&b[N * j + k].map(|c| c.conj()));
}
}
}
g
}
/// Computes the LDL decomposition of a 2x2 matrix G such that
/// L D L* = G
/// where D is diagonal, and L is lower-triangular. The elements of the matrices are in FFT domain.
pub fn ldl(
g: [Polynomial<Complex64>; 4],
) -> ([Polynomial<Complex64>; 4], [Polynomial<Complex64>; 4]) {
let zero = Polynomial::<Complex64>::zero();
let one = Polynomial::<Complex64>::one();
let l10 = g[2].hadamard_div(&g[0]);
let bc = l10.map(|c| c * c.conj());
let abc = g[0].hadamard_mul(&bc);
let d11 = g[3].clone() - abc;
let l = [one.clone(), zero.clone(), l10.clone(), one];
let d = [g[0].clone(), zero.clone(), zero, d11];
(l, d)
}
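The same algebra in the degree-0 (scalar) case, with f64 in place of polynomials in FFT form, makes the decomposition easy to check by hand: L is unit lower-triangular with entry l10 = g10/g00, and d11 is the Schur complement. A minimal sketch (hypothetical `ldl_2x2`, not the crate's API):

```rust
// Scalar instance of the 2x2 LDL decomposition: G = L·D·Lᵀ.
fn ldl_2x2(g00: f64, g10: f64, g11: f64) -> (f64, f64, f64) {
    let l10 = g10 / g00; // mirrors l10 = g[2].hadamard_div(&g[0])
    let d11 = g11 - g00 * l10 * l10; // Schur complement, as in d11 = g[3] - abc
    (l10, g00, d11) // (the L entry, the two D diagonal entries)
}

fn main() {
    let (g00, g10, g11) = (4.0, 2.0, 3.0);
    let (l10, d00, d11) = ldl_2x2(g00, g10, g11);
    // reconstruct G = L D L^T entrywise
    assert!((d00 - g00).abs() < 1e-12);
    assert!((l10 * d00 - g10).abs() < 1e-12);
    assert!((l10 * l10 * d00 + d11 - g11).abs() < 1e-12);
}
```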
#[derive(Debug, Clone)]
pub enum LdlTree {
Branch(Polynomial<Complex64>, Box<LdlTree>, Box<LdlTree>),
Leaf([Complex64; 2]),
}
/// Computes the LDL Tree of G. Corresponds to Algorithm 9 of the specification [1, p.37].
/// The argument is a 2x2 matrix of polynomials, given in FFT form.
/// [1]: https://falcon-sign.info/falcon.pdf
pub fn ffldl(gram_matrix: [Polynomial<Complex64>; 4]) -> LdlTree {
let n = gram_matrix[0].coefficients.len();
let (l, d) = ldl(gram_matrix);
if n > 2 {
let (d00, d01) = d[0].split_fft();
let (d10, d11) = d[3].split_fft();
let g0 = [d00.clone(), d01.clone(), d01.map(|c| c.conj()), d00];
let g1 = [d10.clone(), d11.clone(), d11.map(|c| c.conj()), d10];
LdlTree::Branch(l[2].clone(), Box::new(ffldl(g0)), Box::new(ffldl(g1)))
} else {
LdlTree::Branch(
l[2].clone(),
Box::new(LdlTree::Leaf(d[0].clone().coefficients.try_into().unwrap())),
Box::new(LdlTree::Leaf(d[3].clone().coefficients.try_into().unwrap())),
)
}
}
/// Normalizes the leaves of an LDL tree using a given normalization value `sigma`.
pub fn normalize_tree(tree: &mut LdlTree, sigma: f64) {
match tree {
LdlTree::Branch(_ell, left, right) => {
normalize_tree(left, sigma);
normalize_tree(right, sigma);
},
LdlTree::Leaf(vector) => {
vector[0] = Complex::new(sigma / vector[0].re.sqrt(), 0.0);
vector[1] = Complex64::zero();
},
}
}
/// Samples short polynomials using a Falcon tree. Algorithm 11 from the spec [1, p.40].
///
/// [1]: https://falcon-sign.info/falcon.pdf
pub fn ffsampling<R: Rng>(
t: &(Polynomial<Complex64>, Polynomial<Complex64>),
tree: &LdlTree,
mut rng: &mut R,
) -> (Polynomial<Complex64>, Polynomial<Complex64>) {
match tree {
LdlTree::Branch(ell, left, right) => {
let bold_t1 = t.1.split_fft();
let bold_z1 = ffsampling(&bold_t1, right, rng);
let z1 = Polynomial::<Complex64>::merge_fft(&bold_z1.0, &bold_z1.1);
// t0' = t0 + (t1 - z1) * l
let t0_prime = t.0.clone() + (t.1.clone() - z1.clone()).hadamard_mul(ell);
let bold_t0 = t0_prime.split_fft();
let bold_z0 = ffsampling(&bold_t0, left, rng);
let z0 = Polynomial::<Complex64>::merge_fft(&bold_z0.0, &bold_z0.1);
(z0, z1)
},
LdlTree::Leaf(value) => {
let z0 = sampler_z(t.0.coefficients[0].re, value[0].re, SIGMIN, &mut rng);
let z1 = sampler_z(t.1.coefficients[0].re, value[0].re, SIGMIN, &mut rng);
(
Polynomial::new(vec![Complex64::new(z0 as f64, 0.0)]),
Polynomial::new(vec![Complex64::new(z1 as f64, 0.0)]),
)
},
}
}

File diff suppressed because it is too large
@@ -0,0 +1,174 @@
use alloc::string::String;
use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use num::{One, Zero};
use super::{fft::CyclotomicFourier, Inverse, MODULUS};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct FalconFelt(u32);
impl FalconFelt {
pub const fn new(value: i16) -> Self {
let gtz_bool = value >= 0;
let gtz_int = gtz_bool as i16;
let gtz_sign = gtz_int - ((!gtz_bool) as i16);
let reduced = gtz_sign * (gtz_sign * value) % MODULUS;
let canonical_representative = (reduced + MODULUS * (1 - gtz_int)) as u32;
FalconFelt(canonical_representative)
}
pub const fn value(&self) -> i16 {
self.0 as i16
}
pub fn balanced_value(&self) -> i16 {
let value = self.value();
let g = (value > ((MODULUS) / 2)) as i16;
value - (MODULUS) * g
}
pub const fn multiply(&self, other: Self) -> Self {
FalconFelt((self.0 * other.0) % MODULUS as u32)
}
}
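The constructor and `balanced_value` above reduce an i16 without branching on the value's sign: the sign is folded in as a ±1 multiplier and the modulus is conditionally added back. A self-contained sketch of the same arithmetic (hypothetical `canonical`/`balanced` helpers):

```rust
// Sketch of the branch-free reduction in FalconFelt::new and balanced_value:
// map an i16 to the canonical range [0, q) and back to (-q/2, q/2].
const Q: i16 = 12289; // the Falcon prime (MODULUS)

fn canonical(value: i16) -> u32 {
    let gtz = value >= 0;
    let sign = (gtz as i16) - (!gtz as i16); // +1 for non-negative, -1 otherwise
    let reduced = sign * ((sign * value) % Q);
    (reduced + Q * (1 - gtz as i16)) as u32
}

fn balanced(c: u32) -> i16 {
    let v = c as i16;
    v - Q * ((v > Q / 2) as i16) // shift the upper half down by q
}

fn main() {
    assert_eq!(canonical(-1), (Q - 1) as u32);
    assert_eq!(canonical(Q), 0);
    assert_eq!(balanced(canonical(-5)), -5);
    assert_eq!(balanced(canonical(6144)), 6144); // q/2 = 6144 stays positive
    assert_eq!(balanced(canonical(6145)), 6145 - Q);
}
```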
impl Add for FalconFelt {
type Output = Self;
#[allow(clippy::suspicious_arithmetic_impl)]
fn add(self, rhs: Self) -> Self::Output {
let (s, _) = self.0.overflowing_add(rhs.0);
let (d, n) = s.overflowing_sub(MODULUS as u32);
let (r, _) = d.overflowing_add(MODULUS as u32 * (n as u32));
FalconFelt(r)
}
}
impl AddAssign for FalconFelt {
fn add_assign(&mut self, rhs: Self) {
*self = *self + rhs;
}
}
impl Sub for FalconFelt {
type Output = Self;
fn sub(self, rhs: Self) -> Self::Output {
self + -rhs
}
}
impl SubAssign for FalconFelt {
fn sub_assign(&mut self, rhs: Self) {
*self = *self - rhs;
}
}
impl Neg for FalconFelt {
type Output = FalconFelt;
fn neg(self) -> Self::Output {
let is_nonzero = self.0 != 0;
let r = MODULUS as u32 - self.0;
FalconFelt(r * (is_nonzero as u32))
}
}
impl Mul for FalconFelt {
type Output = Self;
fn mul(self, rhs: Self) -> Self::Output {
FalconFelt((self.0 * rhs.0) % MODULUS as u32)
}
}
impl MulAssign for FalconFelt {
fn mul_assign(&mut self, rhs: Self) {
*self = *self * rhs;
}
}
impl Div for FalconFelt {
type Output = FalconFelt;
#[allow(clippy::suspicious_arithmetic_impl)]
fn div(self, rhs: Self) -> Self::Output {
self * rhs.inverse_or_zero()
}
}
impl DivAssign for FalconFelt {
fn div_assign(&mut self, rhs: Self) {
*self = *self / rhs
}
}
impl Zero for FalconFelt {
fn zero() -> Self {
FalconFelt::new(0)
}
fn is_zero(&self) -> bool {
self.0 == 0
}
}
impl One for FalconFelt {
fn one() -> Self {
FalconFelt::new(1)
}
}
impl Inverse for FalconFelt {
fn inverse_or_zero(self) -> Self {
// q-2 = 0b10 11 11 11 11 11 11
let two = self.multiply(self);
let three = two.multiply(self);
let six = three.multiply(three);
let twelve = six.multiply(six);
let fifteen = twelve.multiply(three);
let thirty = fifteen.multiply(fifteen);
let sixty = thirty.multiply(thirty);
let sixty_three = sixty.multiply(three);
let sixty_three_sq = sixty_three.multiply(sixty_three);
let sixty_three_qu = sixty_three_sq.multiply(sixty_three_sq);
let sixty_three_oc = sixty_three_qu.multiply(sixty_three_qu);
let sixty_three_hx = sixty_three_oc.multiply(sixty_three_oc);
let sixty_three_tt = sixty_three_hx.multiply(sixty_three_hx);
let sixty_three_sf = sixty_three_tt.multiply(sixty_three_tt);
let all_ones = sixty_three_sf.multiply(sixty_three);
let two_e_twelve = all_ones.multiply(self);
let two_e_thirteen = two_e_twelve.multiply(two_e_twelve);
two_e_thirteen.multiply(all_ones)
}
}
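The addition chain above is a fixed-sequence computation of a^(q-2) mod q, i.e. Fermat inversion for q = 12289 (q - 2 = 0b10_1111_1111_1111). A plain square-and-multiply sketch for cross-checking that exponent:

```rust
// Square-and-multiply sketch of Fermat inversion: inv(a) = a^(q-2) mod q.
const Q: u64 = 12289;

fn pow_mod(mut base: u64, mut exp: u64) -> u64 {
    let mut acc = 1u64;
    base %= Q;
    while exp > 0 {
        if exp & 1 == 1 {
            acc = acc * base % Q;
        }
        base = base * base % Q;
        exp >>= 1;
    }
    acc
}

fn main() {
    for a in [1u64, 2, 1331, 12288] {
        let inv = pow_mod(a, Q - 2);
        assert_eq!(a * inv % Q, 1); // a · a^(q-2) ≡ 1 (mod q)
    }
}
```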
impl CyclotomicFourier for FalconFelt {
fn primitive_root_of_unity(n: usize) -> Self {
let log2n = n.ilog2();
assert!(log2n <= 12);
// 1331 is a primitive 2^12-th root of unity modulo q = 12289
let mut a = FalconFelt::new(1331);
let num_squarings = 12 - n.ilog2();
for _ in 0..num_squarings {
a *= a;
}
a
}
}
impl TryFrom<u32> for FalconFelt {
type Error = String;
fn try_from(value: u32) -> Result<Self, Self::Error> {
if value >= MODULUS as u32 {
Err(format!("value {value} is greater than or equal to the field modulus {MODULUS}"))
} else {
Ok(FalconFelt::new(value as i16))
}
}
}

@@ -0,0 +1,322 @@
//! Contains different structs and methods related to the Falcon DSA.
//!
//! It uses and acknowledges the work in:
//!
//! 1. The [reference](https://falcon-sign.info/impl/README.txt.html) implementation by Thomas
//! Pornin.
//! 2. The [Rust](https://github.com/aszepieniec/falcon-rust) implementation by Alan Szepieniec.
use alloc::{string::String, vec::Vec};
use core::ops::MulAssign;
#[cfg(not(feature = "std"))]
use num::Float;
use num::{BigInt, FromPrimitive, One, Zero};
use num_complex::Complex64;
use rand::Rng;
use super::MODULUS;
mod fft;
pub use fft::{CyclotomicFourier, FastFft};
mod field;
pub use field::FalconFelt;
mod ffsampling;
pub use ffsampling::{ffldl, ffsampling, gram, normalize_tree, LdlTree};
mod samplerz;
use self::samplerz::sampler_z;
mod polynomial;
pub use polynomial::Polynomial;
pub trait Inverse: Copy + Zero + MulAssign + One {
/// Gets the inverse of a, or zero if it is zero.
fn inverse_or_zero(self) -> Self;
/// Gets the inverses of a batch of elements, skipping over any that are zero.
fn batch_inverse_or_zero(batch: &[Self]) -> Vec<Self> {
let mut acc = Self::one();
let mut rp: Vec<Self> = Vec::with_capacity(batch.len());
for batch_item in batch {
if !batch_item.is_zero() {
rp.push(acc);
acc = *batch_item * acc;
} else {
rp.push(Self::zero());
}
}
let mut inv = Self::inverse_or_zero(acc);
for i in (0..batch.len()).rev() {
if !batch[i].is_zero() {
rp[i] *= inv;
inv *= batch[i];
}
}
rp
}
}
impl Inverse for Complex64 {
fn inverse_or_zero(self) -> Self {
let modulus = self.re * self.re + self.im * self.im;
Complex64::new(self.re / modulus, -self.im / modulus)
}
fn batch_inverse_or_zero(batch: &[Self]) -> Vec<Self> {
batch.iter().map(|&c| Complex64::new(1.0, 0.0) / c).collect()
}
}
impl Inverse for f64 {
fn inverse_or_zero(self) -> Self {
1.0 / self
}
fn batch_inverse_or_zero(batch: &[Self]) -> Vec<Self> {
batch.iter().map(|&c| 1.0 / c).collect()
}
}
/// Samples 4 small polynomials f, g, F, G such that f * G - g * F = q mod (X^n + 1).
/// Algorithm 5 (NTRUgen) of the documentation [1, p.34].
///
/// [1]: https://falcon-sign.info/falcon.pdf
pub(crate) fn ntru_gen<R: Rng>(n: usize, rng: &mut R) -> [Polynomial<i16>; 4] {
loop {
let f = gen_poly(n, rng);
let g = gen_poly(n, rng);
let f_ntt = f.map(|&i| FalconFelt::new(i)).fft();
if f_ntt.coefficients.iter().any(|e| e.is_zero()) {
continue;
}
let gamma = gram_schmidt_norm_squared(&f, &g);
if gamma > 1.3689f64 * (MODULUS as f64) {
continue;
}
if let Some((capital_f, capital_g)) =
ntru_solve(&f.map(|&i| i.into()), &g.map(|&i| i.into()))
{
return [
g,
-f,
capital_g.map(|i| i.try_into().unwrap()),
-capital_f.map(|i| i.try_into().unwrap()),
];
}
}
}
/// Solves the NTRU equation. Given f, g in ZZ[X], find F, G in ZZ[X] such that:
///
/// f G - g F = q mod (X^n + 1)
///
/// Algorithm 6 of the specification [1, p.35].
///
/// [1]: https://falcon-sign.info/falcon.pdf
fn ntru_solve(
f: &Polynomial<BigInt>,
g: &Polynomial<BigInt>,
) -> Option<(Polynomial<BigInt>, Polynomial<BigInt>)> {
let n = f.coefficients.len();
if n == 1 {
let (gcd, u, v) = xgcd(&f.coefficients[0], &g.coefficients[0]);
if gcd != BigInt::one() {
return None;
}
return Some((
(Polynomial::new(vec![-v * BigInt::from_u32(MODULUS as u32).unwrap()])),
Polynomial::new(vec![u * BigInt::from_u32(MODULUS as u32).unwrap()]),
));
}
let f_prime = f.field_norm();
let g_prime = g.field_norm();
let (capital_f_prime, capital_g_prime) = ntru_solve(&f_prime, &g_prime)?;
let capital_f_prime_xsq = capital_f_prime.lift_next_cyclotomic();
let capital_g_prime_xsq = capital_g_prime.lift_next_cyclotomic();
let f_minx = f.galois_adjoint();
let g_minx = g.galois_adjoint();
let mut capital_f = (capital_f_prime_xsq.karatsuba(&g_minx)).reduce_by_cyclotomic(n);
let mut capital_g = (capital_g_prime_xsq.karatsuba(&f_minx)).reduce_by_cyclotomic(n);
match babai_reduce(f, g, &mut capital_f, &mut capital_g) {
Ok(_) => Some((capital_f, capital_g)),
Err(_e) => {
#[cfg(test)]
{
panic!("{}", _e);
}
#[cfg(not(test))]
{
None
}
},
}
}
/// Generates a polynomial of degree at most n-1 whose coefficients are distributed according
/// to a discrete Gaussian with mu = 0 and sigma = 1.17 * sqrt(Q / (2n)).
fn gen_poly<R: Rng>(n: usize, rng: &mut R) -> Polynomial<i16> {
let mu = 0.0;
let sigma_star = 1.43300980528773;
Polynomial {
coefficients: (0..4096)
.map(|_| sampler_z(mu, sigma_star, sigma_star - 0.001, rng))
.collect::<Vec<i16>>()
.chunks(4096 / n)
.map(|ch| ch.iter().sum())
.collect(),
}
}
/// Computes the squared Gram-Schmidt norm of B = [[g, -f], [G, -F]] from f and g.
/// Corresponds to line 9 in algorithm 5 of the spec [1, p.34]
///
/// [1]: https://falcon-sign.info/falcon.pdf
fn gram_schmidt_norm_squared(f: &Polynomial<i16>, g: &Polynomial<i16>) -> f64 {
let n = f.coefficients.len();
let norm_f_squared = f.l2_norm_squared();
let norm_g_squared = g.l2_norm_squared();
let gamma1 = norm_f_squared + norm_g_squared;
let f_fft = f.map(|i| Complex64::new(*i as f64, 0.0)).fft();
let g_fft = g.map(|i| Complex64::new(*i as f64, 0.0)).fft();
let f_adj_fft = f_fft.map(|c| c.conj());
let g_adj_fft = g_fft.map(|c| c.conj());
let ffgg_fft = f_fft.hadamard_mul(&f_adj_fft) + g_fft.hadamard_mul(&g_adj_fft);
let ffgg_fft_inverse = ffgg_fft.hadamard_inv();
let qf_over_ffgg_fft = f_adj_fft.map(|c| c * (MODULUS as f64)).hadamard_mul(&ffgg_fft_inverse);
let qg_over_ffgg_fft = g_adj_fft.map(|c| c * (MODULUS as f64)).hadamard_mul(&ffgg_fft_inverse);
let norm_f_over_ffgg_squared =
qf_over_ffgg_fft.coefficients.iter().map(|c| (c * c.conj()).re).sum::<f64>() / (n as f64);
let norm_g_over_ffgg_squared =
qg_over_ffgg_fft.coefficients.iter().map(|c| (c * c.conj()).re).sum::<f64>() / (n as f64);
let gamma2 = norm_f_over_ffgg_squared + norm_g_over_ffgg_squared;
f64::max(gamma1, gamma2)
}
/// Reduces the vector (F,G) relative to (f,g). This method follows the python implementation [1].
/// Note that this algorithm can end up in an infinite loop. (It's one of the things the author
/// would like to fix.) When this happens, the function returns an error (hence the return type)
/// and the caller generates another key pair with fresh randomness.
///
/// Algorithm 7 in the spec [2, p.35]
///
/// [1]: https://github.com/tprest/falcon.py
///
/// [2]: https://falcon-sign.info/falcon.pdf
fn babai_reduce(
f: &Polynomial<BigInt>,
g: &Polynomial<BigInt>,
capital_f: &mut Polynomial<BigInt>,
capital_g: &mut Polynomial<BigInt>,
) -> Result<(), String> {
let bitsize = |bi: &BigInt| (bi.bits() + 7) & (u64::MAX ^ 7);
let n = f.coefficients.len();
let size = [
f.map(bitsize).fold(0, |a, &b| u64::max(a, b)),
g.map(bitsize).fold(0, |a, &b| u64::max(a, b)),
53,
]
.into_iter()
.max()
.unwrap();
let shift = (size as i64) - 53;
let f_adjusted = f
.map(|bi| Complex64::new(i64::try_from(bi >> shift).unwrap() as f64, 0.0))
.fft();
let g_adjusted = g
.map(|bi| Complex64::new(i64::try_from(bi >> shift).unwrap() as f64, 0.0))
.fft();
let f_star_adjusted = f_adjusted.map(|c| c.conj());
let g_star_adjusted = g_adjusted.map(|c| c.conj());
let denominator_fft =
f_adjusted.hadamard_mul(&f_star_adjusted) + g_adjusted.hadamard_mul(&g_star_adjusted);
let mut counter = 0;
loop {
let capital_size = [
capital_f.map(bitsize).fold(0, |a, &b| u64::max(a, b)),
capital_g.map(bitsize).fold(0, |a, &b| u64::max(a, b)),
53,
]
.into_iter()
.max()
.unwrap();
if capital_size < size {
break;
}
let capital_shift = (capital_size as i64) - 53;
let capital_f_adjusted = capital_f
.map(|bi| Complex64::new(i64::try_from(bi >> capital_shift).unwrap() as f64, 0.0))
.fft();
let capital_g_adjusted = capital_g
.map(|bi| Complex64::new(i64::try_from(bi >> capital_shift).unwrap() as f64, 0.0))
.fft();
let numerator = capital_f_adjusted.hadamard_mul(&f_star_adjusted)
+ capital_g_adjusted.hadamard_mul(&g_star_adjusted);
let quotient = numerator.hadamard_div(&denominator_fft).ifft();
let k = quotient.map(|f| Into::<BigInt>::into(f.re.round() as i64));
if k.is_zero() {
break;
}
let kf = (k.clone().karatsuba(f))
.reduce_by_cyclotomic(n)
.map(|bi| bi << (capital_size - size));
let kg = (k.clone().karatsuba(g))
.reduce_by_cyclotomic(n)
.map(|bi| bi << (capital_size - size));
*capital_f -= kf;
*capital_g -= kg;
counter += 1;
if counter > 1000 {
// If we get here, it means that (with high likelihood) we are stuck in an
// infinite loop. We know it happens from time to time -- seldom, but it
// does. It would be nice to fix that! But in order to fix it we need to be
// able to reproduce it, and for that we need test vectors. So print them
// and hope that one day they circle back to the implementor.
return Err(format!(
"Encountered infinite loop in babai_reduce of falcon-rust.\n\
Please help the developer(s) fix it! You can do this by sending them the inputs to the \
function that caused the behavior:\n\
f: {:?}\n\
g: {:?}\n\
capital_f: {:?}\n\
capital_g: {:?}\n",
f.coefficients, g.coefficients, capital_f.coefficients, capital_g.coefficients
));
}
}
Ok(())
}
/// Extended Euclidean algorithm for computing the greatest common divisor (g) and
/// Bézout coefficients (u, v) for the relation
///
/// $$ u a + v b = g . $$
///
/// Implementation adapted from Wikipedia [1].
///
/// [1]: https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm#Pseudocode
fn xgcd(a: &BigInt, b: &BigInt) -> (BigInt, BigInt, BigInt) {
let (mut old_r, mut r) = (a.clone(), b.clone());
let (mut old_s, mut s) = (BigInt::one(), BigInt::zero());
let (mut old_t, mut t) = (BigInt::zero(), BigInt::one());
while r != BigInt::zero() {
let quotient = old_r.clone() / r.clone();
(old_r, r) = (r.clone(), old_r.clone() - quotient.clone() * r);
(old_s, s) = (s.clone(), old_s.clone() - quotient.clone() * s);
(old_t, t) = (t.clone(), old_t.clone() - quotient * t);
}
(old_r, old_s, old_t)
}
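The loop above maintains the invariant that the Bézout coefficients always satisfy u·a + v·b = old_r. A standalone i64 sketch (hypothetical `xgcd_i64`, stdlib-only; the real function needs `BigInt` because NTRU coefficients overflow machine integers) demonstrates the identity:

```rust
// Minimal i64 analogue of xgcd above; illustrates the invariant
// u*a + v*b = g maintained by the iteration. Sketch only.
fn xgcd_i64(a: i64, b: i64) -> (i64, i64, i64) {
    let (mut old_r, mut r) = (a, b);
    let (mut old_s, mut s) = (1i64, 0i64);
    let (mut old_t, mut t) = (0i64, 1i64);
    while r != 0 {
        let q = old_r / r;
        (old_r, r) = (r, old_r - q * r);
        (old_s, s) = (s, old_s - q * s);
        (old_t, t) = (t, old_t - q * t);
    }
    (old_r, old_s, old_t)
}

fn main() {
    let (g, u, v) = xgcd_i64(240, 46);
    assert_eq!(g, 2); // gcd(240, 46) = 2
    assert_eq!(u * 240 + v * 46, g); // Bézout identity holds
}
```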


@@ -0,0 +1,622 @@
use alloc::vec::Vec;
use core::{
default::Default,
fmt::Debug,
ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign},
};
use num::{One, Zero};
use super::{field::FalconFelt, Inverse};
use crate::{
dsa::rpo_falcon512::{MODULUS, N},
Felt,
};
#[derive(Debug, Clone, Default)]
pub struct Polynomial<F> {
pub coefficients: Vec<F>,
}
impl<F> Polynomial<F>
where
F: Clone,
{
pub fn new(coefficients: Vec<F>) -> Self {
Self { coefficients }
}
}
impl<
F: Mul<Output = F> + Sub<Output = F> + AddAssign + Zero + Div<Output = F> + Clone + Inverse,
> Polynomial<F>
{
pub fn hadamard_mul(&self, other: &Self) -> Self {
Polynomial::new(
self.coefficients
.iter()
.zip(other.coefficients.iter())
.map(|(a, b)| *a * *b)
.collect(),
)
}
pub fn hadamard_div(&self, other: &Self) -> Self {
let other_coefficients_inverse = F::batch_inverse_or_zero(&other.coefficients);
Polynomial::new(
self.coefficients
.iter()
.zip(other_coefficients_inverse.iter())
.map(|(a, b)| *a * *b)
.collect(),
)
}
pub fn hadamard_inv(&self) -> Self {
let coefficients_inverse = F::batch_inverse_or_zero(&self.coefficients);
Polynomial::new(coefficients_inverse)
}
}
impl<F: Zero + PartialEq + Clone> Polynomial<F> {
pub fn degree(&self) -> Option<usize> {
if self.coefficients.is_empty() {
return None;
}
let mut max_index = self.coefficients.len() - 1;
while self.coefficients[max_index] == F::zero() {
if let Some(new_index) = max_index.checked_sub(1) {
max_index = new_index;
} else {
return None;
}
}
Some(max_index)
}
pub fn lc(&self) -> F {
match self.degree() {
Some(non_negative_degree) => self.coefficients[non_negative_degree].clone(),
None => F::zero(),
}
}
}
/// The following implementations are specific to cyclotomic polynomial rings,
/// i.e., F\[ X \] / <X^n + 1>, and are used extensively in Falcon.
impl<
F: One
+ Zero
+ Clone
+ Neg<Output = F>
+ MulAssign
+ AddAssign
+ Div<Output = F>
+ Sub<Output = F>
+ PartialEq,
> Polynomial<F>
{
/// Reduce the polynomial by X^n + 1.
pub fn reduce_by_cyclotomic(&self, n: usize) -> Self {
let mut coefficients = vec![F::zero(); n];
let mut sign = -F::one();
for (i, c) in self.coefficients.iter().cloned().enumerate() {
if i % n == 0 {
sign *= -F::one();
}
coefficients[i % n] += sign.clone() * c;
}
Polynomial::new(coefficients)
}
/// Computes the field norm of the polynomial as an element of the cyclotomic ring
/// F\[ X \] / <X^n + 1 > relative to one of half the size, i.e., F\[ X \] / <X^(n/2) + 1> .
///
/// Corresponds to formula 3.25 in the spec [1, p.30].
///
/// [1]: https://falcon-sign.info/falcon.pdf
pub fn field_norm(&self) -> Self {
let n = self.coefficients.len();
let mut f0_coefficients = vec![F::zero(); n / 2];
let mut f1_coefficients = vec![F::zero(); n / 2];
for i in 0..n / 2 {
f0_coefficients[i] = self.coefficients[2 * i].clone();
f1_coefficients[i] = self.coefficients[2 * i + 1].clone();
}
let f0 = Polynomial::new(f0_coefficients);
let f1 = Polynomial::new(f1_coefficients);
let f0_squared = (f0.clone() * f0).reduce_by_cyclotomic(n / 2);
let f1_squared = (f1.clone() * f1).reduce_by_cyclotomic(n / 2);
let x = Polynomial::new(vec![F::zero(), F::one()]);
f0_squared - (x * f1_squared).reduce_by_cyclotomic(n / 2)
}
/// Lifts an element from a cyclotomic polynomial ring to one of double the size.
pub fn lift_next_cyclotomic(&self) -> Self {
let n = self.coefficients.len();
let mut coefficients = vec![F::zero(); n * 2];
for i in 0..n {
coefficients[2 * i] = self.coefficients[i].clone();
}
Self::new(coefficients)
}
/// Computes the galois adjoint of the polynomial in the cyclotomic ring F\[ X \] / < X^n + 1 >
/// , which corresponds to f(x^2).
pub fn galois_adjoint(&self) -> Self {
Self::new(
self.coefficients
.iter()
.enumerate()
.map(|(i, c)| if i % 2 == 0 { c.clone() } else { c.clone().neg() })
.collect(),
)
}
}
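Since X^n ≡ -1 in F[X] / <X^n + 1>, coefficient i folds onto position i mod n with sign (-1)^(i/n), which is exactly the alternating `sign` in `reduce_by_cyclotomic` above. A stdlib-only i64 sketch (hypothetical `reduce_by_cyclotomic_i64`) makes the sign rule visible:

```rust
// Reduce a coefficient vector by X^n + 1: coefficient i folds onto
// position i % n with sign (-1)^(i / n), since X^n ≡ -1 in the quotient.
fn reduce_by_cyclotomic_i64(coeffs: &[i64], n: usize) -> Vec<i64> {
    let mut out = vec![0i64; n];
    for (i, &c) in coeffs.iter().enumerate() {
        if (i / n) % 2 == 0 {
            out[i % n] += c;
        } else {
            out[i % n] -= c;
        }
    }
    out
}

fn main() {
    // 1 + X^2 with n = 2: X^2 ≡ -1, so the sum reduces to zero.
    assert_eq!(reduce_by_cyclotomic_i64(&[1, 0, 1], 2), vec![0, 0]);
    // X^3 with n = 2: X^3 = X * X^2 ≡ -X.
    assert_eq!(reduce_by_cyclotomic_i64(&[0, 0, 0, 1], 2), vec![0, -1]);
}
```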
impl<F: Clone + Into<f64>> Polynomial<F> {
pub(crate) fn l2_norm_squared(&self) -> f64 {
self.coefficients
.iter()
.map(|i| Into::<f64>::into(i.clone()))
.map(|i| i * i)
.sum::<f64>()
}
}
impl<F> PartialEq for Polynomial<F>
where
F: Zero + PartialEq + Clone + AddAssign,
{
fn eq(&self, other: &Self) -> bool {
if self.is_zero() && other.is_zero() {
true
} else if self.is_zero() || other.is_zero() {
false
} else {
let self_degree = self.degree().unwrap();
let other_degree = other.degree().unwrap();
self.coefficients[0..=self_degree] == other.coefficients[0..=other_degree]
}
}
}
impl<F> Eq for Polynomial<F> where F: Zero + PartialEq + Clone + AddAssign {}
impl<F> Add for &Polynomial<F>
where
F: Add<Output = F> + AddAssign + Clone,
{
type Output = Polynomial<F>;
fn add(self, rhs: Self) -> Self::Output {
let coefficients = if self.coefficients.len() >= rhs.coefficients.len() {
let mut coefficients = self.coefficients.clone();
for (i, c) in rhs.coefficients.iter().enumerate() {
coefficients[i] += c.clone();
}
coefficients
} else {
let mut coefficients = rhs.coefficients.clone();
for (i, c) in self.coefficients.iter().enumerate() {
coefficients[i] += c.clone();
}
coefficients
};
Self::Output { coefficients }
}
}
impl<F> Add for Polynomial<F>
where
F: Add<Output = F> + AddAssign + Clone,
{
type Output = Polynomial<F>;
fn add(self, rhs: Self) -> Self::Output {
let coefficients = if self.coefficients.len() >= rhs.coefficients.len() {
let mut coefficients = self.coefficients.clone();
for (i, c) in rhs.coefficients.into_iter().enumerate() {
coefficients[i] += c;
}
coefficients
} else {
let mut coefficients = rhs.coefficients.clone();
for (i, c) in self.coefficients.into_iter().enumerate() {
coefficients[i] += c;
}
coefficients
};
Self::Output { coefficients }
}
}
impl<F> AddAssign for Polynomial<F>
where
F: Add<Output = F> + AddAssign + Clone,
{
fn add_assign(&mut self, rhs: Self) {
if self.coefficients.len() >= rhs.coefficients.len() {
for (i, c) in rhs.coefficients.into_iter().enumerate() {
self.coefficients[i] += c;
}
} else {
let mut coefficients = rhs.coefficients.clone();
for (i, c) in self.coefficients.iter().enumerate() {
coefficients[i] += c.clone();
}
self.coefficients = coefficients;
}
}
}
impl<F> Sub for &Polynomial<F>
where
F: Sub<Output = F> + Clone + Neg<Output = F> + Add<Output = F> + AddAssign,
{
type Output = Polynomial<F>;
fn sub(self, rhs: Self) -> Self::Output {
self + &(-rhs)
}
}
impl<F> Sub for Polynomial<F>
where
F: Sub<Output = F> + Clone + Neg<Output = F> + Add<Output = F> + AddAssign,
{
type Output = Polynomial<F>;
fn sub(self, rhs: Self) -> Self::Output {
self + (-rhs)
}
}
impl<F> SubAssign for Polynomial<F>
where
F: Add<Output = F> + Neg<Output = F> + AddAssign + Clone + Sub<Output = F>,
{
fn sub_assign(&mut self, rhs: Self) {
self.coefficients = self.clone().sub(rhs).coefficients;
}
}
impl<F: Neg<Output = F> + Clone> Neg for &Polynomial<F> {
type Output = Polynomial<F>;
fn neg(self) -> Self::Output {
Self::Output {
coefficients: self.coefficients.iter().cloned().map(|a| -a).collect(),
}
}
}
impl<F: Neg<Output = F> + Clone> Neg for Polynomial<F> {
type Output = Self;
fn neg(self) -> Self::Output {
Self::Output {
coefficients: self.coefficients.iter().cloned().map(|a| -a).collect(),
}
}
}
impl<F> Mul for &Polynomial<F>
where
F: Add + AddAssign + Mul<Output = F> + Sub<Output = F> + Zero + PartialEq + Clone,
{
type Output = Polynomial<F>;
fn mul(self, other: Self) -> Self::Output {
if self.is_zero() || other.is_zero() {
return Polynomial::<F>::zero();
}
let mut coefficients =
vec![F::zero(); self.coefficients.len() + other.coefficients.len() - 1];
for i in 0..self.coefficients.len() {
for j in 0..other.coefficients.len() {
coefficients[i + j] += self.coefficients[i].clone() * other.coefficients[j].clone();
}
}
Polynomial { coefficients }
}
}
impl<F> Mul for Polynomial<F>
where
F: Add + AddAssign + Mul<Output = F> + Zero + PartialEq + Clone,
{
type Output = Self;
fn mul(self, other: Self) -> Self::Output {
if self.is_zero() || other.is_zero() {
return Self::zero();
}
let mut coefficients =
vec![F::zero(); self.coefficients.len() + other.coefficients.len() - 1];
for i in 0..self.coefficients.len() {
for j in 0..other.coefficients.len() {
coefficients[i + j] += self.coefficients[i].clone() * other.coefficients[j].clone();
}
}
Self { coefficients }
}
}
impl<F: Add + Mul<Output = F> + Zero + Clone> Mul<F> for &Polynomial<F> {
type Output = Polynomial<F>;
fn mul(self, other: F) -> Self::Output {
Polynomial {
coefficients: self.coefficients.iter().cloned().map(|i| i * other.clone()).collect(),
}
}
}
impl<F: Add + Mul<Output = F> + Zero + Clone> Mul<F> for Polynomial<F> {
type Output = Polynomial<F>;
fn mul(self, other: F) -> Self::Output {
Polynomial {
coefficients: self.coefficients.iter().cloned().map(|i| i * other.clone()).collect(),
}
}
}
impl<F: Mul<Output = F> + Sub<Output = F> + AddAssign + Zero + Div<Output = F> + Clone>
Polynomial<F>
{
/// Multiply two polynomials using Karatsuba's divide-and-conquer algorithm.
pub fn karatsuba(&self, other: &Self) -> Self {
Polynomial::new(vector_karatsuba(&self.coefficients, &other.coefficients))
}
}
impl<F> One for Polynomial<F>
where
F: Clone + One + PartialEq + Zero + AddAssign,
{
fn one() -> Self {
Self { coefficients: vec![F::one()] }
}
}
impl<F> Zero for Polynomial<F>
where
F: Zero + PartialEq + Clone + AddAssign,
{
fn zero() -> Self {
Self { coefficients: vec![] }
}
fn is_zero(&self) -> bool {
self.degree().is_none()
}
}
impl<F: Zero + Clone> Polynomial<F> {
pub fn shift(&self, shamt: usize) -> Self {
Self {
coefficients: [vec![F::zero(); shamt], self.coefficients.clone()].concat(),
}
}
pub fn constant(f: F) -> Self {
Self { coefficients: vec![f] }
}
pub fn map<G: Clone, C: FnMut(&F) -> G>(&self, closure: C) -> Polynomial<G> {
Polynomial::<G>::new(self.coefficients.iter().map(closure).collect())
}
pub fn fold<G, C: FnMut(G, &F) -> G + Clone>(&self, mut initial_value: G, closure: C) -> G {
for c in self.coefficients.iter() {
initial_value = (closure.clone())(initial_value, c);
}
initial_value
}
}
impl<F> Div<Polynomial<F>> for Polynomial<F>
where
F: Zero
+ One
+ PartialEq
+ AddAssign
+ Clone
+ Mul<Output = F>
+ MulAssign
+ Div<Output = F>
+ Neg<Output = F>
+ Sub<Output = F>,
{
type Output = Polynomial<F>;
fn div(self, denominator: Self) -> Self::Output {
if denominator.is_zero() {
panic!("cannot divide by the zero polynomial");
}
if self.is_zero() {
return Self::zero();
}
let mut remainder = self.clone();
let mut quotient = Polynomial::<F>::zero();
while remainder.degree().unwrap() >= denominator.degree().unwrap() {
let shift = remainder.degree().unwrap() - denominator.degree().unwrap();
let quotient_coefficient = remainder.lc() / denominator.lc();
let monomial = Self::constant(quotient_coefficient).shift(shift);
quotient += monomial.clone();
remainder -= monomial * denominator.clone();
if remainder.is_zero() {
break;
}
}
quotient
}
}
fn vector_karatsuba<
F: Zero + AddAssign + Mul<Output = F> + Sub<Output = F> + Div<Output = F> + Clone,
>(
left: &[F],
right: &[F],
) -> Vec<F> {
let n = left.len();
if n <= 8 {
let mut product = vec![F::zero(); left.len() + right.len() - 1];
for (i, l) in left.iter().enumerate() {
for (j, r) in right.iter().enumerate() {
product[i + j] += l.clone() * r.clone();
}
}
return product;
}
let n_over_2 = n / 2;
let mut product = vec![F::zero(); 2 * n - 1];
let left_lo = &left[0..n_over_2];
let right_lo = &right[0..n_over_2];
let left_hi = &left[n_over_2..];
let right_hi = &right[n_over_2..];
let left_sum: Vec<F> =
left_lo.iter().zip(left_hi).map(|(a, b)| a.clone() + b.clone()).collect();
let right_sum: Vec<F> =
right_lo.iter().zip(right_hi).map(|(a, b)| a.clone() + b.clone()).collect();
let prod_lo = vector_karatsuba(left_lo, right_lo);
let prod_hi = vector_karatsuba(left_hi, right_hi);
let prod_mid: Vec<F> = vector_karatsuba(&left_sum, &right_sum)
.iter()
.zip(prod_lo.iter().zip(prod_hi.iter()))
.map(|(s, (l, h))| s.clone() - (l.clone() + h.clone()))
.collect();
for (i, l) in prod_lo.into_iter().enumerate() {
product[i] = l;
}
for (i, m) in prod_mid.into_iter().enumerate() {
product[i + n_over_2] += m;
}
for (i, h) in prod_hi.into_iter().enumerate() {
product[i + n] += h;
}
product
}
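Karatsuba trades one multiplication for a few additions: the middle term is recovered as (lo + hi)·(lo' + hi') - lo·lo' - hi·hi'. A stdlib-only sketch (hypothetical names `schoolbook` and `karatsuba_step`; the recursion of `vector_karatsuba` is replaced by a single splitting step for brevity) checks the identity against the naive product:

```rust
// Schoolbook product, used as the reference.
fn schoolbook(a: &[i64], b: &[i64]) -> Vec<i64> {
    let mut out = vec![0i64; a.len() + b.len() - 1];
    for (i, &x) in a.iter().enumerate() {
        for (j, &y) in b.iter().enumerate() {
            out[i + j] += x * y;
        }
    }
    out
}

// One Karatsuba step: lo*lo', hi*hi', and (lo+hi)*(lo'+hi') recover the
// middle term as mid = sum_product - lo_product - hi_product.
fn karatsuba_step(a: &[i64], b: &[i64]) -> Vec<i64> {
    let n = a.len();
    let m = n / 2;
    let (alo, ahi) = a.split_at(m);
    let (blo, bhi) = b.split_at(m);
    let asum: Vec<i64> = alo.iter().zip(ahi).map(|(x, y)| x + y).collect();
    let bsum: Vec<i64> = blo.iter().zip(bhi).map(|(x, y)| x + y).collect();
    let plo = schoolbook(alo, blo);
    let phi = schoolbook(ahi, bhi);
    let psum = schoolbook(&asum, &bsum);
    let mut out = vec![0i64; 2 * n - 1];
    for i in 0..plo.len() {
        out[i] += plo[i];
        out[i + n] += phi[i];
        out[i + m] += psum[i] - plo[i] - phi[i];
    }
    out
}

fn main() {
    let a = [1, 2, 3, 4];
    let b = [5, 6, 7, 8];
    assert_eq!(karatsuba_step(&a, &b), schoolbook(&a, &b));
}
```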
impl From<Polynomial<FalconFelt>> for Polynomial<Felt> {
fn from(item: Polynomial<FalconFelt>) -> Self {
let res: Vec<Felt> =
item.coefficients.iter().map(|a| Felt::from(a.value() as u16)).collect();
Polynomial::new(res)
}
}
impl From<&Polynomial<FalconFelt>> for Polynomial<Felt> {
fn from(item: &Polynomial<FalconFelt>) -> Self {
let res: Vec<Felt> =
item.coefficients.iter().map(|a| Felt::from(a.value() as u16)).collect();
Polynomial::new(res)
}
}
impl From<Polynomial<i16>> for Polynomial<FalconFelt> {
fn from(item: Polynomial<i16>) -> Self {
let res: Vec<FalconFelt> = item.coefficients.iter().map(|&a| FalconFelt::new(a)).collect();
Polynomial::new(res)
}
}
impl From<&Polynomial<i16>> for Polynomial<FalconFelt> {
fn from(item: &Polynomial<i16>) -> Self {
let res: Vec<FalconFelt> = item.coefficients.iter().map(|&a| FalconFelt::new(a)).collect();
Polynomial::new(res)
}
}
impl From<Vec<i16>> for Polynomial<FalconFelt> {
fn from(item: Vec<i16>) -> Self {
let res: Vec<FalconFelt> = item.iter().map(|&a| FalconFelt::new(a)).collect();
Polynomial::new(res)
}
}
impl From<&Vec<i16>> for Polynomial<FalconFelt> {
fn from(item: &Vec<i16>) -> Self {
let res: Vec<FalconFelt> = item.iter().map(|&a| FalconFelt::new(a)).collect();
Polynomial::new(res)
}
}
impl Polynomial<FalconFelt> {
pub fn norm_squared(&self) -> u64 {
self.coefficients
.iter()
.map(|&i| i.balanced_value() as i64)
.map(|i| (i * i) as u64)
.sum::<u64>()
}
// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------
/// Returns the coefficients of this polynomial as field elements.
pub fn to_elements(&self) -> Vec<Felt> {
self.coefficients.iter().map(|&a| Felt::from(a.value() as u16)).collect()
}
// POLYNOMIAL OPERATIONS
// --------------------------------------------------------------------------------------------
/// Multiplies two polynomials over Z_p\[x\] without reducing modulo p. Given that the degrees
/// of the input polynomials are less than 512 and their coefficients are less than the modulus
/// q equal to 12289, the resulting product polynomial is guaranteed to have coefficients less
/// than the Miden prime.
///
/// Note that this multiplication is not over Z_p\[x\]/(phi).
pub fn mul_modulo_p(a: &Self, b: &Self) -> [u64; 1024] {
let mut c = [0; 2 * N];
for i in 0..N {
for j in 0..N {
c[i + j] += a.coefficients[i].value() as u64 * b.coefficients[j].value() as u64;
}
}
c
}
/// Reduces a polynomial, that is the product of two polynomials over Z_p\[x\], modulo
/// the irreducible polynomial phi. This results in an element in Z_p\[x\]/(phi).
pub fn reduce_negacyclic(a: &[u64; 1024]) -> Self {
let mut c = [FalconFelt::zero(); N];
let modulus = MODULUS as u16;
for i in 0..N {
let ai = a[N + i] % modulus as u64;
let neg_ai = (modulus - ai as u16) % modulus;
let bi = (a[i] % modulus as u64) as u16;
c[i] = FalconFelt::new(((neg_ai + bi) % modulus) as i16);
}
Self::new(c.to_vec())
}
}
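`reduce_negacyclic` amounts to c[i] = a[i] - a[N + i] (mod p), since X^N ≡ -1 modulo phi. A toy n = 4 version (hypothetical `reduce_negacyclic_toy`; the method above fixes n = 512 and p = 12289) shows the arithmetic:

```rust
// Negacyclic reduction of a length-2n product modulo X^n + 1 over Z_p:
// c[i] = a[i] - a[n + i] (mod p), because X^n ≡ -1.
fn reduce_negacyclic_toy(a: &[u64; 8], p: u64) -> [u64; 4] {
    let mut c = [0u64; 4];
    for i in 0..4 {
        let lo = a[i] % p;
        let hi = a[4 + i] % p;
        c[i] = (lo + p - hi) % p; // lo - hi, kept non-negative mod p
    }
    c
}

fn main() {
    let p = 12289;
    // The monomial X^4 reduces to -1, i.e. p - 1 in Z_p.
    let a = [0, 0, 0, 0, 1, 0, 0, 0];
    assert_eq!(reduce_negacyclic_toy(&a, p), [p - 1, 0, 0, 0]);
}
```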
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
use super::{FalconFelt, Polynomial, N};
#[test]
fn test_negacyclic_reduction() {
let coef1: [u8; N] = rand_utils::rand_array();
let coef2: [u8; N] = rand_utils::rand_array();
let poly1 = Polynomial::new(coef1.iter().map(|&a| FalconFelt::new(a as i16)).collect());
let poly2 = Polynomial::new(coef2.iter().map(|&a| FalconFelt::new(a as i16)).collect());
let prod = poly1.clone() * poly2.clone();
assert_eq!(
prod.reduce_by_cyclotomic(N),
Polynomial::reduce_negacyclic(&Polynomial::mul_modulo_p(&poly1, &poly2))
);
}
}


@@ -0,0 +1,299 @@
use core::f64::consts::LN_2;
#[cfg(not(feature = "std"))]
use num::Float;
use rand::Rng;
/// Samples an integer from {0, ..., 18} according to the distribution χ, which is close to
/// the half-Gaussian distribution on the natural numbers with mean 0 and standard deviation
/// equal to sigma_max.
fn base_sampler(bytes: [u8; 9]) -> i16 {
const RCDT: [u128; 18] = [
3024686241123004913666,
1564742784480091954050,
636254429462080897535,
199560484645026482916,
47667343854657281903,
8595902006365044063,
1163297957344668388,
117656387352093658,
8867391802663976,
496969357462633,
20680885154299,
638331848991,
14602316184,
247426747,
3104126,
28824,
198,
1,
];
let u = u128::from_be_bytes([vec![0u8; 7], bytes.to_vec()].concat().try_into().unwrap());
RCDT.into_iter().filter(|r| u < *r).count() as i16
}
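`base_sampler` simply counts how many reverse-CDT entries the 72-bit random value falls below. The same filter-and-count pattern with a tiny made-up threshold table (hypothetical `toy_rcdt_sample`; these are not Falcon's RCDT values):

```rust
// Toy reverse-CDT sampler: returns the number of table entries strictly
// greater than u. With the table below, u in [60,100) -> 0, [30,60) -> 1,
// [10,30) -> 2, [0,10) -> 3 (illustrative thresholds only).
fn toy_rcdt_sample(u: u32) -> usize {
    const RCDT: [u32; 3] = [60, 30, 10];
    RCDT.iter().filter(|&&r| u < r).count()
}

fn main() {
    assert_eq!(toy_rcdt_sample(75), 0);
    assert_eq!(toy_rcdt_sample(45), 1);
    assert_eq!(toy_rcdt_sample(20), 2);
    assert_eq!(toy_rcdt_sample(5), 3);
}
```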
/// Computes an integer approximation of 2^63 * ccs * exp(-x).
fn approx_exp(x: f64, ccs: f64) -> u64 {
// The constants C are used to approximate exp(-x); these
// constants are taken from FACCT (up to a scaling factor
// of 2^63):
// https://eprint.iacr.org/2018/1234
// https://github.com/raykzhao/gaussian
const C: [u64; 13] = [
0x00000004741183a3u64,
0x00000036548cfc06u64,
0x0000024fdcbf140au64,
0x0000171d939de045u64,
0x0000d00cf58f6f84u64,
0x000680681cf796e3u64,
0x002d82d8305b0feau64,
0x011111110e066fd0u64,
0x0555555555070f00u64,
0x155555555581ff00u64,
0x400000000002b400u64,
0x7fffffffffff4800u64,
0x8000000000000000u64,
];
let mut z: u64;
let mut y: u64;
let twoe63 = 1u64 << 63;
y = C[0];
z = f64::floor(x * (twoe63 as f64)) as u64;
for cu in C.iter().skip(1) {
let zy = (z as u128) * (y as u128);
y = cu - ((zy >> 63) as u64);
}
z = f64::floor((twoe63 as f64) * ccs) as u64;
(((z as u128) * (y as u128)) >> 63) as u64
}
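Because `approx_exp` is pure integer arithmetic, it can be exercised standalone. The sketch below reproduces the function verbatim and checks it against the first known answer from `test_approx_exp` further down, using that test's 2^14 precision bound:

```rust
// approx_exp from the module above, reproduced so it can run standalone;
// constants are from FACCT, scaled by 2^63.
fn approx_exp(x: f64, ccs: f64) -> u64 {
    const C: [u64; 13] = [
        0x00000004741183a3, 0x00000036548cfc06, 0x0000024fdcbf140a,
        0x0000171d939de045, 0x0000d00cf58f6f84, 0x000680681cf796e3,
        0x002d82d8305b0fea, 0x011111110e066fd0, 0x0555555555070f00,
        0x155555555581ff00, 0x400000000002b400, 0x7fffffffffff4800,
        0x8000000000000000,
    ];
    let twoe63 = 1u64 << 63;
    let mut z = f64::floor(x * (twoe63 as f64)) as u64;
    let mut y = C[0];
    for cu in C.iter().skip(1) {
        let zy = (z as u128) * (y as u128);
        y = cu - ((zy >> 63) as u64);
    }
    z = f64::floor((twoe63 as f64) * ccs) as u64;
    (((z as u128) * (y as u128)) >> 63) as u64
}

fn main() {
    // First known answer from the module's own test_approx_exp.
    let got = approx_exp(0.2314993926072656, 0.8148006314615972);
    let diff = (5962140072160879737i128 - got as i128).abs();
    assert!(diff <= 16384); // 2^14, the precision bound used by the test
}
```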
/// A random bool that is true with probability ≈ ccs · exp(-x).
fn ber_exp(x: f64, ccs: f64, random_bytes: [u8; 7]) -> bool {
// 0.69314718055994530941 = ln(2)
let s = f64::floor(x / LN_2) as usize;
let r = x - LN_2 * (s as f64);
let shamt = usize::min(s, 63);
let z = ((((approx_exp(r, ccs) as u128) << 1) - 1) >> shamt) as u64;
let mut w = 0i16;
for (index, i) in (0..64).step_by(8).rev().enumerate() {
let byte = random_bytes[index];
w = (byte as i16) - (((z >> i) & 0xff) as i16);
if w != 0 {
break;
}
}
w < 0
}
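The loop at the end of `ber_exp` is a lazy big-endian comparison: it walks from the most significant byte down, and the first differing byte decides the sign. An 8-byte standalone variant (hypothetical `lt_lazy`) isolates just that comparison:

```rust
// Lazy big-endian byte comparison, as in ber_exp: walk both values from
// the most significant byte; the first differing byte decides the order.
fn lt_lazy(random_bytes: [u8; 8], z: u64) -> bool {
    let mut w = 0i16;
    for (index, i) in (0..64).step_by(8).rev().enumerate() {
        w = (random_bytes[index] as i16) - (((z >> i) & 0xff) as i16);
        if w != 0 {
            break;
        }
    }
    w < 0
}

fn main() {
    assert!(lt_lazy([0u8; 8], 1)); // 0 < 1
    assert!(!lt_lazy(0x0100u64.to_be_bytes(), 0xff)); // 256 > 255
}
```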
/// Samples an integer from the Gaussian distribution with given mean (mu) and standard deviation
/// (sigma).
pub(crate) fn sampler_z<R: Rng>(mu: f64, sigma: f64, sigma_min: f64, rng: &mut R) -> i16 {
const SIGMA_MAX: f64 = 1.8205;
const INV_2SIGMA_MAX_SQ: f64 = 1f64 / (2f64 * SIGMA_MAX * SIGMA_MAX);
let isigma = 1f64 / sigma;
let dss = 0.5f64 * isigma * isigma;
let s = f64::floor(mu);
let r = mu - s;
let ccs = sigma_min * isigma;
loop {
let z0 = base_sampler(rng.gen());
let random_byte: u8 = rng.gen();
let b = (random_byte & 1) as i16;
let z = b + ((b << 1) - 1) * z0;
let zf_min_r = (z as f64) - r;
// x = ((z-r)^2)/(2*sigma^2) - ((z-b)^2)/(2*sigma0^2)
let x = zf_min_r * zf_min_r * dss - (z0 * z0) as f64 * INV_2SIGMA_MAX_SQ;
if ber_exp(x, ccs, rng.gen()) {
return z + (s as i16);
}
}
}
#[cfg(all(test, feature = "std"))]
mod test {
use alloc::vec::Vec;
use std::{thread::sleep, time::Duration};
use rand::RngCore;
use super::{approx_exp, ber_exp, sampler_z};
/// RNG used only for testing purposes, whereby the produced
/// string of random bytes is equal to the one it is initialized
/// with. Whatever you do, do not use this RNG in production.
struct UnsafeBufferRng {
buffer: Vec<u8>,
index: usize,
}
impl UnsafeBufferRng {
fn new(buffer: &[u8]) -> Self {
Self { buffer: buffer.to_vec(), index: 0 }
}
fn next(&mut self) -> u8 {
if self.buffer.len() <= self.index {
// Buffer exhausted: stall briefly and return zeros instead of panicking.
sleep(Duration::from_millis(10));
0u8
} else {
let return_value = self.buffer[self.index];
self.index += 1;
return_value
}
}
}
impl RngCore for UnsafeBufferRng {
fn next_u32(&mut self) -> u32 {
u32::from_le_bytes([self.next(), 0, 0, 0])
}
fn next_u64(&mut self) -> u64 {
u64::from_le_bytes([self.next(), 0, 0, 0, 0, 0, 0, 0])
}
fn fill_bytes(&mut self, dest: &mut [u8]) {
for d in dest.iter_mut() {
*d = self.next();
}
}
fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
for d in dest.iter_mut() {
*d = self.next();
}
Ok(())
}
}
#[test]
fn test_unsafe_buffer_rng() {
let seed_bytes = hex::decode("7FFECD162AE2").unwrap();
let mut rng = UnsafeBufferRng::new(&seed_bytes);
let generated_bytes: Vec<u8> = (0..seed_bytes.len()).map(|_| rng.next()).collect();
assert_eq!(seed_bytes, generated_bytes);
}
#[test]
fn test_approx_exp() {
let precision = 1u64 << 14;
// known answers were generated with the following sage script:
//```sage
// num_samples = 10
// precision = 200
// R = Reals(precision)
//
// print(f"let kats : [(f64, f64, u64);{num_samples}] = [")
// for i in range(num_samples):
// x = RDF.random_element(0.0, 0.693147180559945)
// ccs = RDF.random_element(0.0, 1.0)
// res = round(2^63 * R(ccs) * exp(R(-x)))
// print(f"({x}, {ccs}, {res}),")
// print("];")
// ```
let kats: [(f64, f64, u64); 10] = [
(0.2314993926072656, 0.8148006314615972, 5962140072160879737),
(0.2648875572812225, 0.12769669655309035, 903712282351034505),
(0.11251957513682391, 0.9264611470305881, 7635725498677341553),
(0.04353439307256617, 0.5306497137523327, 4685877322232397936),
(0.41834495299784347, 0.879438856118578, 5338392138535350986),
(0.32579398973228557, 0.16513412873289002, 1099603299296456803),
(0.5939508073919817, 0.029776019144967303, 151637565622779016),
(0.2932367999399056, 0.37123847662857923, 2553827649386670452),
(0.5005699297417507, 0.31447208863888976, 1758235618083658825),
(0.4876437338498085, 0.6159515298936868, 3488632981903743976),
];
for (x, ccs, answer) in kats {
let difference = (answer as i128) - (approx_exp(x, ccs) as i128);
assert!(
(difference * difference) as u64 <= precision * precision,
"answer: {answer} versus approximation: {}\ndifference: {} whereas precision: {}",
approx_exp(x, ccs),
difference,
precision
);
}
}
#[test]
fn test_ber_exp() {
let kats = [
(
1.268_314_048_020_498_4,
0.749_990_853_267_664_9,
hex::decode("ea000000000000").unwrap(),
false,
),
(
0.001_563_917_959_143_409_6,
0.749_990_853_267_664_9,
hex::decode("6c000000000000").unwrap(),
true,
),
(
0.017_921_215_753_999_235,
0.749_990_853_267_664_9,
hex::decode("c2000000000000").unwrap(),
false,
),
(
0.776_117_648_844_980_6,
0.751_181_554_542_520_8,
hex::decode("58000000000000").unwrap(),
true,
),
];
for (x, ccs, bytes, answer) in kats {
assert_eq!(answer, ber_exp(x, ccs, bytes.try_into().unwrap()));
}
}
#[test]
fn test_sampler_z() {
let sigma_min = 1.277833697;
// known answers from the doc, table 3.2, page 44
// https://falcon-sign.info/falcon.pdf
// The zeros were added to account for dropped bytes.
let kats = [
(-91.90471153063714,1.7037990414754918,hex::decode("0fc5442ff043d66e91d1ea000000000000cac64ea5450a22941edc6c").unwrap(),-92),
(-8.322564895434937,1.7037990414754918,hex::decode("f4da0f8d8444d1a77265c2000000000000ef6f98bbbb4bee7db8d9b3").unwrap(),-8),
(-19.096516109216804,1.7035823083824078,hex::decode("db47f6d7fb9b19f25c36d6000000000000b9334d477a8bc0be68145d").unwrap(),-20),
(-11.335543982423326, 1.7035823083824078, hex::decode("ae41b4f5209665c74d00dc000000000000c1a8168a7bb516b3190cb42c1ded26cd52000000000000aed770eca7dd334e0547bcc3c163ce0b").unwrap(), -12),
(7.9386734193997555, 1.6984647769450156, hex::decode("31054166c1012780c603ae0000000000009b833cec73f2f41ca5807c000000000000c89c92158834632f9b1555").unwrap(), 8),
(-28.990850086867255, 1.6984647769450156, hex::decode("737e9d68a50a06dbbc6477").unwrap(), -30),
(-9.071257914091655, 1.6980782114808988, hex::decode("a98ddd14bf0bf22061d632").unwrap(), -10),
(-43.88754568839566, 1.6980782114808988, hex::decode("3cbf6818a68f7ab9991514").unwrap(), -41),
(-58.17435547946095,1.7010983419195522,hex::decode("6f8633f5bfa5d26848668e0000000000003d5ddd46958e97630410587c").unwrap(),-61),
(-43.58664906684732, 1.7010983419195522, hex::decode("272bc6c25f5c5ee53f83c40000000000003a361fbc7cc91dc783e20a").unwrap(), -46),
(-34.70565203313315, 1.7009387219711465, hex::decode("45443c59574c2c3b07e2e1000000000000d9071e6d133dbe32754b0a").unwrap(), -34),
(-44.36009577368896, 1.7009387219711465, hex::decode("6ac116ed60c258e2cbaeab000000000000728c4823e6da36e18d08da0000000000005d0cc104e21cc7fd1f5ca8000000000000d9dbb675266c928448059e").unwrap(), -44),
(-21.783037079346236, 1.6958406126012802, hex::decode("68163bc1e2cbf3e18e7426").unwrap(), -23),
(-39.68827784633828, 1.6958406126012802, hex::decode("d6a1b51d76222a705a0259").unwrap(), -40),
(-18.488607061056847, 1.6955259305261838, hex::decode("f0523bfaa8a394bf4ea5c10000000000000f842366fde286d6a30803").unwrap(), -22),
(-48.39610939101591, 1.6955259305261838, hex::decode("87bd87e63374cee62127fc0000000000006931104aab64f136a0485b").unwrap(), -50),
];
for (mu, sigma, random_bytes, answer) in kats {
assert_eq!(
sampler_z(mu, sigma, sigma_min, &mut UnsafeBufferRng::new(&random_bytes)),
answer
);
}
}
}


@@ -1,60 +1,105 @@
use crate::{
hash::rpo::Rpo256,
utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
Felt, Word, ZERO,
};
mod hash_to_point;
mod keys;
mod math;
mod signature;
pub use self::{
keys::{PubKeyPoly, PublicKey, SecretKey},
math::Polynomial,
signature::{Signature, SignatureHeader, SignaturePoly},
};
// CONSTANTS
// ================================================================================================
// The Falcon modulus p.
const MODULUS: i16 = 12289;
// Number of bits needed to encode an element in the Falcon field.
const FALCON_ENCODING_BITS: u32 = 14;
// The Falcon-512 parameter N. This is the degree of the polynomial `phi := x^N + 1`
// defining the ring Z_p[x]/(phi).
const N: usize = 512;
const LOG_N: u8 = 9;
/// Length of the nonce used during signature generation.
const SIG_NONCE_LEN: usize = 40;
/// Number of field elements used to encode a nonce.
const NONCE_ELEMENTS: usize = 8;
/// Public key length as a u8 vector.
pub const PK_LEN: usize = 897;
/// Secret key length as a u8 vector.
pub const SK_LEN: usize = 1281;
/// Length of the encoded signature polynomial as a u8 vector.
const SIG_POLY_BYTE_LEN: usize = 625;
/// Bound on the squared-norm of the signature.
const SIG_L2_BOUND: u64 = 34034726;
/// Standard deviation of the Gaussian over the lattice.
const SIGMA: f64 = 165.7366171829776;
// TYPE ALIASES
// ================================================================================================
type NonceElements = [Felt; NONCE_ELEMENTS];
type ShortLatticeBasis = [Polynomial<i16>; 4];
// NONCE
// ================================================================================================
/// Nonce of the Falcon signature.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Nonce([u8; SIG_NONCE_LEN]);
impl Nonce {
/// Returns a new [Nonce] instantiated from the provided bytes.
pub fn new(bytes: [u8; SIG_NONCE_LEN]) -> Self {
Self(bytes)
}
/// Returns the underlying bytes of this nonce.
pub fn as_bytes(&self) -> &[u8; SIG_NONCE_LEN] {
&self.0
}
/// Converts byte representation of the nonce into field element representation.
///
/// Nonce bytes are converted to field elements by taking consecutive 5-byte chunks
/// of the nonce and interpreting each chunk as a field element.
pub fn to_elements(&self) -> [Felt; NONCE_ELEMENTS] {
let mut buffer = [0_u8; 8];
let mut result = [ZERO; 8];
for (i, bytes) in self.0.chunks(5).enumerate() {
buffer[..5].copy_from_slice(bytes);
// we can safely (without overflow) create a new Felt from u64 value here since this
// value contains at most 5 bytes
result[i] = Felt::new(u64::from_le_bytes(buffer));
}
result
}
}
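`Nonce::to_elements` packs the 40 nonce bytes into eight values of at most 40 bits each, so every value is guaranteed to be a canonical field element. A Felt-free sketch (hypothetical `nonce_to_u64s`) shows the packing:

```rust
// Pack a 40-byte nonce into eight u64 values, five bytes each, mirroring
// Nonce::to_elements before the values are wrapped in Felt. Each value is
// below 2^40, hence always smaller than the field modulus.
fn nonce_to_u64s(nonce: &[u8; 40]) -> [u64; 8] {
    let mut result = [0u64; 8];
    for (i, chunk) in nonce.chunks(5).enumerate() {
        let mut buffer = [0u8; 8];
        buffer[..5].copy_from_slice(chunk);
        result[i] = u64::from_le_bytes(buffer);
    }
    result
}

fn main() {
    let mut nonce = [0u8; 40];
    nonce[0] = 0x01; // first element becomes 1
    nonce[5] = 0xff; // second element becomes 0xff
    let elems = nonce_to_u64s(&nonce);
    assert_eq!(elems[0], 1);
    assert_eq!(elems[1], 0xff);
    assert!(elems.iter().all(|&e| e < (1u64 << 40)));
}
```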
impl Serializable for &Nonce {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write_bytes(&self.0)
}
}
impl Deserializable for Nonce {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let bytes = source.read()?;
Ok(Self(bytes))
}
}


@@ -1,277 +0,0 @@
use super::{FalconError, Felt, Vec, LOG_N, MODULUS, MODULUS_MINUS_1_OVER_TWO, N, PK_LEN};
use core::ops::{Add, Mul, Sub};
// FALCON POLYNOMIAL
// ================================================================================================
/// A polynomial over Z_p[x]/(phi) where phi := x^512 + 1
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Polynomial([u16; N]);
impl Polynomial {
// CONSTRUCTORS
// --------------------------------------------------------------------------------------------
/// Constructs a new polynomial from a list of coefficients.
///
/// # Safety
/// This constructor validates that the coefficients are in the valid range only in debug mode.
pub unsafe fn new(data: [u16; N]) -> Self {
for value in data {
debug_assert!(value < MODULUS);
}
Self(data)
}
/// Decodes raw bytes representing a public key into a polynomial in Z_p[x]/(phi).
///
/// # Errors
/// Returns an error if:
/// - The provided input is not exactly 897 bytes long.
/// - The first byte of the input is not equal to log2(512) i.e., 9.
/// - Any of the coefficients encoded in the provided input is greater than or equal to the
/// Falcon field modulus.
pub fn from_pub_key(input: &[u8]) -> Result<Self, FalconError> {
if input.len() != PK_LEN {
return Err(FalconError::PubKeyDecodingInvalidLength(input.len()));
}
if input[0] != LOG_N as u8 {
return Err(FalconError::PubKeyDecodingInvalidTag(input[0]));
}
let mut acc = 0_u32;
let mut acc_len = 0;
let mut output = [0_u16; N];
let mut output_idx = 0;
for &byte in input.iter().skip(1) {
acc = (acc << 8) | (byte as u32);
acc_len += 8;
if acc_len >= 14 {
acc_len -= 14;
let w = (acc >> acc_len) & 0x3FFF;
if w >= MODULUS as u32 {
return Err(FalconError::PubKeyDecodingInvalidCoefficient(w));
}
output[output_idx] = w as u16;
output_idx += 1;
}
}
if (acc & ((1u32 << acc_len) - 1)) == 0 {
Ok(Self(output))
} else {
Err(FalconError::PubKeyDecodingExtraData)
}
}
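The 14-bit big-endian unpacking loop in `from_pub_key` can be exercised standalone. The sketch below reproduces the same accumulator logic on an arbitrary byte slice; the name `unpack_14bit` and the free-standing form are ours for illustration, not the crate's API:

```rust
const MODULUS: u32 = 12289;

/// Unpacks big-endian 14-bit values from a byte stream, rejecting values
/// >= MODULUS and non-zero padding bits (same logic as `from_pub_key`).
fn unpack_14bit(bytes: &[u8]) -> Option<Vec<u16>> {
    let (mut acc, mut acc_len) = (0u32, 0u32);
    let mut out = Vec::new();
    for &b in bytes {
        acc = (acc << 8) | b as u32;
        acc_len += 8;
        if acc_len >= 14 {
            acc_len -= 14;
            let w = (acc >> acc_len) & 0x3FFF;
            if w >= MODULUS {
                return None;
            }
            out.push(w as u16);
        }
    }
    // the remaining (fewer than 14) bits must be zero padding
    if acc & ((1u32 << acc_len) - 1) == 0 {
        Some(out)
    } else {
        None
    }
}

fn main() {
    // two 14-bit values (1 and 2) packed into 4 bytes with 4 zero pad bits
    let coeffs = unpack_14bit(&[0x00, 0x04, 0x00, 0x20]).unwrap();
    assert_eq!(coeffs, vec![1, 2]);
    // 0x3FFF = 16383 >= 12289, so decoding must fail
    assert!(unpack_14bit(&[0xFF, 0xFF]).is_none());
    println!("decoded: {coeffs:?}");
}
```

This mirrors why a 512-coefficient key needs 512 * 14 / 8 = 896 bytes after the tag byte.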
/// Decodes the signature into the coefficients of a polynomial in Z_p[x]/(phi). It assumes
/// that the signature has been encoded using the reference compressed format.
///
/// # Errors
/// Returns an error if:
/// - The signature has been encoded using a different algorithm than the reference compressed
/// encoding algorithm.
/// - The encoded signature polynomial is in Z_p[x]/(phi') where phi' = x^N' + 1 and N' != 512.
/// - While decoding the high bits of a coefficient, the current accumulated value of its
/// high bits is larger than 2048.
/// - The decoded coefficient is -0.
/// - The remaining unused bits in the last byte of `input` are non-zero.
pub fn from_signature(input: &[u8]) -> Result<Self, FalconError> {
let (encoding, log_n) = (input[0] >> 4, input[0] & 0b00001111);
if encoding != 0b0011 {
return Err(FalconError::SigDecodingIncorrectEncodingAlgorithm);
}
if log_n != 0b1001 {
return Err(FalconError::SigDecodingNotSupportedDegree(log_n));
}
let input = &input[41..];
let mut input_idx = 0;
let mut acc = 0u32;
let mut acc_len = 0;
let mut output = [0_u16; N];
for e in output.iter_mut() {
acc = (acc << 8) | (input[input_idx] as u32);
input_idx += 1;
let b = acc >> acc_len;
let s = b & 128;
let mut m = b & 127;
loop {
if acc_len == 0 {
acc = (acc << 8) | (input[input_idx] as u32);
input_idx += 1;
acc_len = 8;
}
acc_len -= 1;
if ((acc >> acc_len) & 1) != 0 {
break;
}
m += 128;
if m >= 2048 {
return Err(FalconError::SigDecodingTooBigHighBits(m));
}
}
if s != 0 && m == 0 {
return Err(FalconError::SigDecodingMinusZero);
}
*e = if s != 0 { (MODULUS as u32 - m) as u16 } else { m as u16 };
}
if (acc & ((1 << acc_len) - 1)) != 0 {
return Err(FalconError::SigDecodingNonZeroUnusedBitsLastByte);
}
Ok(Self(output))
}
// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------
/// Returns the coefficients of this polynomial as integers.
pub fn inner(&self) -> [u16; N] {
self.0
}
/// Returns the coefficients of this polynomial as field elements.
pub fn to_elements(&self) -> Vec<Felt> {
self.0.iter().map(|&a| Felt::from(a)).collect()
}
// POLYNOMIAL OPERATIONS
// --------------------------------------------------------------------------------------------
/// Multiplies two polynomials over Z_p[x] without reducing modulo p. Given that the degrees
/// of the input polynomials are less than 512 and their coefficients are less than the modulus
/// q equal to 12289, the resulting product polynomial is guaranteed to have coefficients less
/// than the Miden prime.
///
/// Note that this multiplication is not over Z_p[x]/(phi).
pub fn mul_modulo_p(a: &Self, b: &Self) -> [u64; 1024] {
let mut c = [0; 2 * N];
for i in 0..N {
for j in 0..N {
c[i + j] += a.0[i] as u64 * b.0[j] as u64;
}
}
c
}
/// Reduces a polynomial, that is the product of two polynomials over Z_p[x], modulo
/// the irreducible polynomial phi. This results in an element in Z_p[x]/(phi).
pub fn reduce_negacyclic(a: &[u64; 1024]) -> Self {
let mut c = [0; N];
for i in 0..N {
let ai = a[N + i] % MODULUS as u64;
let neg_ai = (MODULUS - ai as u16) % MODULUS;
let bi = (a[i] % MODULUS as u64) as u16;
c[i] = (neg_ai + bi) % MODULUS;
}
Self(c)
}
/// Computes the norm squared of a polynomial in Z_p[x]/(phi) after normalizing its
/// coefficients to be in the interval (-p/2, p/2].
pub fn sq_norm(&self) -> u64 {
let mut res = 0;
for e in self.0 {
if e > MODULUS_MINUS_1_OVER_TWO {
res += (MODULUS - e) as u64 * (MODULUS - e) as u64
} else {
res += e as u64 * e as u64
}
}
res
}
}
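The normalization in `sq_norm` maps each coefficient to its balanced representative in (-p/2, p/2] before squaring. A miniature sketch of the same arithmetic (the helper names here are illustrative):

```rust
const MODULUS: u32 = 12289;

/// Balanced representative of e in (-p/2, p/2].
fn balanced(e: u32) -> i64 {
    if e > (MODULUS - 1) / 2 {
        e as i64 - MODULUS as i64
    } else {
        e as i64
    }
}

/// Squared l2-norm over the balanced representatives.
fn sq_norm(coeffs: &[u32]) -> u64 {
    coeffs.iter().map(|&e| (balanced(e) * balanced(e)) as u64).sum()
}

fn main() {
    // 12288 is the representative of -1, so it contributes 1, not 12288^2
    assert_eq!(sq_norm(&[12288, 1, 2]), 1 + 1 + 4);
    println!("norm^2 = {}", sq_norm(&[12288, 1, 2]));
}
```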
// Returns a polynomial representing the zero polynomial i.e. default element.
impl Default for Polynomial {
fn default() -> Self {
Self([0_u16; N])
}
}
/// Multiplication over Z_p[x]/(phi)
impl Mul for Polynomial {
type Output = Self;
fn mul(self, other: Self) -> <Self as Mul<Self>>::Output {
let mut result = [0_u16; N];
for j in 0..N {
for k in 0..N {
let i = (j + k) % N;
let a = self.0[j] as usize;
let b = other.0[k] as usize;
let q = MODULUS as usize;
let mut prod = a * b % q;
if (N - 1) < (j + k) {
prod = (q - prod) % q;
}
result[i] = ((result[i] as usize + prod) % q) as u16;
}
}
Polynomial(result)
}
}
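The sign flip in the `Mul` impl whenever `(N - 1) < (j + k)` is the negacyclic wrap-around: in Z_q[x]/(x^N + 1), x^N = -1. A small standalone sketch with n = 4 makes this visible (the function name is ours):

```rust
/// Schoolbook multiplication in Z_q[x]/(x^n + 1): whenever j + k >= n, the
/// term wraps around with a sign flip since x^n = -1.
fn negacyclic_mul(a: &[u64], b: &[u64], q: u64) -> Vec<u64> {
    let n = a.len();
    let mut c = vec![0u64; n];
    for j in 0..n {
        for k in 0..n {
            let mut prod = a[j] * b[k] % q;
            if j + k >= n {
                prod = (q - prod) % q;
            }
            let i = (j + k) % n;
            c[i] = (c[i] + prod) % q;
        }
    }
    c
}

fn main() {
    let q = 12289;
    // x * x^3 = x^4 = -1 mod (x^4 + 1), i.e. the constant q - 1
    let a = [0, 1, 0, 0]; // x
    let b = [0, 0, 0, 1]; // x^3
    assert_eq!(negacyclic_mul(&a, &b, q), vec![q - 1, 0, 0, 0]);
    println!("{:?}", negacyclic_mul(&a, &b, q));
}
```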
/// Addition over Z_p[x]/(phi)
impl Add for Polynomial {
type Output = Self;
fn add(self, other: Self) -> <Self as Add<Self>>::Output {
let mut res = self;
res.0.iter_mut().zip(other.0.iter()).for_each(|(x, y)| *x = (*x + *y) % MODULUS);
res
}
}
/// Subtraction over Z_p[x]/(phi)
impl Sub for Polynomial {
type Output = Self;
fn sub(self, other: Self) -> <Self as Sub<Self>>::Output {
let mut res = self;
res.0
.iter_mut()
.zip(other.0.iter())
.for_each(|(x, y)| *x = (*x + MODULUS - *y) % MODULUS);
res
}
}
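The `*x + MODULUS - *y` pattern in the `Sub` impl keeps the intermediate value non-negative so the unsigned subtraction never underflows (and `x + MODULUS` stays well below `u16::MAX` since both operands are reduced). A quick check of the arithmetic:

```rust
fn main() {
    let p: u32 = 12289;
    // 3 - 5 mod p: adding p first keeps the intermediate non-negative
    let (x, y) = (3u32, 5u32);
    assert_eq!((x + p - y) % p, 12287);
    // sanity: adding y back recovers x
    assert_eq!((12287 + y) % p, x);
    println!("3 - 5 = {} (mod {})", (x + p - y) % p, p);
}
```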
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
use super::{Polynomial, N};
#[test]
fn test_negacyclic_reduction() {
let coef1: [u16; N] = rand_utils::rand_array();
let coef2: [u16; N] = rand_utils::rand_array();
let poly1 = Polynomial(coef1);
let poly2 = Polynomial(coef2);
assert_eq!(
poly1 * poly2,
Polynomial::reduce_negacyclic(&Polynomial::mul_modulo_p(&poly1, &poly2))
);
}
}


@@ -1,262 +1,375 @@
use alloc::{string::ToString, vec::Vec};
use core::ops::Deref;
use num::Zero;
use super::{
hash_to_point::hash_to_point_rpo256,
keys::PubKeyPoly,
math::{FalconFelt, FastFft, Polynomial},
ByteReader, ByteWriter, Deserializable, DeserializationError, Felt, Nonce, Rpo256,
Serializable, Word, LOG_N, MODULUS, N, SIG_L2_BOUND, SIG_POLY_BYTE_LEN,
};
// FALCON SIGNATURE
// ================================================================================================
/// An RPO Falcon512 signature over a message.
///
/// The signature is a pair of polynomials (s1, s2) in (Z_p\[x\]/(phi))^2, a nonce `r`, and a public
/// key polynomial `h` where:
/// - p := 12289
/// - phi := x^512 + 1
///
/// The signature verifies against a public key `pk` if and only if:
/// 1. s1 = c - s2 * h
/// 2. |s1|^2 + |s2|^2 <= SIG_L2_BOUND
///
/// where |.| is the norm and:
/// - c = HashToPoint(r || message)
/// - pk = Rpo256::hash(h)
///
/// Here h is a polynomial representing the public key and pk is its digest using the Rpo256 hash
/// function. c is a polynomial that is the hash-to-point of the message being signed.
///
/// The polynomial h is serialized as:
/// 1. 1 byte representing the log2(512) i.e., 9.
/// 2. 896 bytes for the public key itself.
///
/// The signature is serialized as:
/// 1. A header byte specifying the algorithm used to encode the coefficients of the `s2` polynomial
///    together with the degree of the irreducible polynomial phi. For RPO Falcon512, the header
///    byte is set to `10111001` which differentiates it from the standardized instantiation of the
///    Falcon signature.
/// 2. 40 bytes for the nonce.
/// 3. 625 bytes encoding the `s2` polynomial above.
///
/// The total size of the signature (including the extended public key) is 1563 bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Signature {
header: SignatureHeader,
nonce: Nonce,
s2: SignaturePoly,
h: PubKeyPoly,
}
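The stated 1563-byte total follows directly from the component sizes (header, nonce, compressed `s2`, and the serialized `h` including its tag byte):

```rust
fn main() {
    let (header, nonce, s2, h) = (1, 40, 625, 897);
    // 897 = 1 tag byte + 896 bytes of 14-bit packed public key coefficients
    assert_eq!(header + nonce + s2 + h, 1563);
    println!("total signature size: {} bytes", header + nonce + s2 + h);
}
```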
impl Signature {
// CONSTRUCTOR
// --------------------------------------------------------------------------------------------
pub fn new(nonce: Nonce, h: PubKeyPoly, s2: SignaturePoly) -> Signature {
Self {
header: SignatureHeader::default(),
nonce,
s2,
h,
}
}
// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------
/// Returns the public key polynomial h.
pub fn pk_poly(&self) -> &PubKeyPoly {
&self.h
}
/// Returns the polynomial representation of the signature in Z_p[x]/(phi).
pub fn sig_poly(&self) -> &Polynomial<FalconFelt> {
&self.s2
}
/// Returns the nonce component of the signature.
pub fn nonce(&self) -> &Nonce {
&self.nonce
}
// SIGNATURE VERIFICATION
// --------------------------------------------------------------------------------------------
/// Returns true if this signature is a valid signature for the specified message generated
/// against the secret key matching the specified public key commitment.
pub fn verify(&self, message: Word, pubkey_com: Word) -> bool {
// compute the hash of the public key polynomial
let h_felt: Polynomial<Felt> = (&**self.pk_poly()).into();
let h_digest: Word = Rpo256::hash_elements(&h_felt.coefficients).into();
let c = hash_to_point_rpo256(message, &self.nonce);
h_digest == pubkey_com && verify_helper(&c, &self.s2, self.pk_poly())
}
}
// SERIALIZATION / DESERIALIZATION
// ================================================================================================
impl Serializable for Signature {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write(&self.header);
target.write(&self.nonce);
target.write(&self.s2);
target.write(&self.h);
}
}
impl Deserializable for Signature {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let header = source.read()?;
let nonce = source.read()?;
let s2 = source.read()?;
let h = source.read()?;
Ok(Self { header, nonce, s2, h })
}
}
// SIGNATURE HEADER
// ================================================================================================
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SignatureHeader(u8);
impl Default for SignatureHeader {
/// According to section 3.11.3 in the specification [1], the signature header has the format
/// `0cc1nnnn` where:
///
/// 1. `cc` signifies the encoding method. `01` denotes using the compression encoding method
/// and `10` denotes encoding using the uncompressed method.
/// 2. `nnnn` encodes `LOG_N`.
///
/// For RPO Falcon 512 we use compression encoding and N = 512. Moreover, to differentiate the
/// RPO Falcon variant from the reference variant using SHAKE256, we flip the first bit in the
/// header. Thus, for RPO Falcon 512 the header is `10111001`
///
/// [1]: https://falcon-sign.info/falcon.pdf
fn default() -> Self {
Self(0b1011_1001)
}
}
impl Serializable for &SignatureHeader {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write_u8(self.0)
}
}
impl Deserializable for SignatureHeader {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let header = source.read_u8()?;
let (encoding, log_n) = (header >> 4, header & 0b00001111);
if encoding != 0b1011 {
return Err(DeserializationError::InvalidValue(
"Failed to decode signature: not supported encoding algorithm".to_string(),
));
}
if log_n != LOG_N {
return Err(DeserializationError::InvalidValue(
format!("Failed to decode signature: only supported irreducible polynomial degree is 512, 2^{log_n} was provided")
));
}
Ok(Self(header))
}
}
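Splitting the header byte into its two nibbles, as `read_from` does, can be checked in isolation:

```rust
fn main() {
    // RPO Falcon512 header: top nibble 0b1011 (compressed encoding with the
    // leading bit flipped), bottom nibble log2(N) = 9
    let header: u8 = 0b1011_1001;
    let (encoding, log_n) = (header >> 4, header & 0b0000_1111);
    assert_eq!(encoding, 0b1011);
    assert_eq!(log_n, 9);
    println!("encoding = {encoding:#06b}, log_n = {log_n}");
}
```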
// SIGNATURE POLYNOMIAL
// ================================================================================================
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SignaturePoly(pub Polynomial<FalconFelt>);
impl Deref for SignaturePoly {
type Target = Polynomial<FalconFelt>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl From<Polynomial<FalconFelt>> for SignaturePoly {
fn from(pk_poly: Polynomial<FalconFelt>) -> Self {
Self(pk_poly)
}
}
impl TryFrom<&[i16; N]> for SignaturePoly {
type Error = ();
fn try_from(coefficients: &[i16; N]) -> Result<Self, Self::Error> {
if are_coefficients_valid(coefficients) {
Ok(Self(coefficients.to_vec().into()))
} else {
Err(())
}
}
}
impl Serializable for &SignaturePoly {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
let sig_coeff: Vec<i16> = self.0.coefficients.iter().map(|a| a.balanced_value()).collect();
let mut sk_bytes = vec![0_u8; SIG_POLY_BYTE_LEN];
let mut acc = 0;
let mut acc_len = 0;
let mut v = 0;
let mut t;
let mut w;
// For each coefficient of x:
// - the sign is encoded on 1 bit
// - the 7 lower bits are encoded naively (binary)
// - the high bits are encoded in unary encoding
//
// Algorithm 17 p. 47 of the specification [1].
//
// [1]: https://falcon-sign.info/falcon.pdf
for &c in sig_coeff.iter() {
acc <<= 1;
t = c;
if t < 0 {
t = -t;
acc |= 1;
}
w = t as u16;
acc <<= 7;
let mask = 127_u32;
acc |= (w as u32) & mask;
w >>= 7;
acc_len += 8;
acc <<= w + 1;
acc |= 1;
acc_len += w + 1;
while acc_len >= 8 {
acc_len -= 8;
sk_bytes[v] = (acc >> acc_len) as u8;
v += 1;
}
}
if acc_len > 0 {
sk_bytes[v] = (acc << (8 - acc_len)) as u8;
}
target.write_bytes(&sk_bytes);
}
}
impl Deserializable for SignaturePoly {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let input = source.read_array::<SIG_POLY_BYTE_LEN>()?;
let mut input_idx = 0;
let mut acc = 0u32;
let mut acc_len = 0;
let mut coefficients = [FalconFelt::zero(); N];
// Algorithm 18 p. 48 of the specification [1].
//
// [1]: https://falcon-sign.info/falcon.pdf
for c in coefficients.iter_mut() {
acc = (acc << 8) | (input[input_idx] as u32);
input_idx += 1;
let b = acc >> acc_len;
let s = b & 128;
let mut m = b & 127;
loop {
if acc_len == 0 {
acc = (acc << 8) | (input[input_idx] as u32);
input_idx += 1;
acc_len = 8;
}
acc_len -= 1;
if ((acc >> acc_len) & 1) != 0 {
break;
}
m += 128;
if m >= 2048 {
return Err(DeserializationError::InvalidValue(format!(
"Failed to decode signature: high bits {m} exceed 2048"
)));
}
}
if s != 0 && m == 0 {
return Err(DeserializationError::InvalidValue(
"Failed to decode signature: -0 is forbidden".to_string(),
));
}
let felt = if s != 0 { (MODULUS as u32 - m) as u16 } else { m as u16 };
*c = FalconFelt::new(felt as i16);
}
if (acc & ((1 << acc_len) - 1)) != 0 {
return Err(DeserializationError::InvalidValue(
"Failed to decode signature: Non-zero unused bits in the last byte".to_string(),
));
}
Ok(Polynomial::new(coefficients.to_vec()).into())
}
}
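The encoder and decoder above are inverses of each other (Algorithms 17 and 18 in the Falcon specification): each coefficient is stored as a sign bit, 7 low bits in binary, and the high bits in unary. The following standalone sketch reimplements the pair over growable buffers to show the roundtrip; names like `compress`/`decompress` are ours, and the crate's versions operate on fixed-size arrays of `FalconFelt`:

```rust
/// Compresses signed coefficients: 1 sign bit, 7 low bits in binary, and the
/// high bits in unary (Falcon spec, Algorithm 17).
fn compress(coeffs: &[i16], out_len: usize) -> Option<Vec<u8>> {
    let mut out = vec![0u8; out_len];
    let (mut acc, mut acc_len, mut v) = (0u32, 0u32, 0usize);
    for &c in coeffs {
        let (mut t, mut sign) = (c, 0u32);
        if t < 0 {
            t = -t;
            sign = 1;
        }
        let w = (t as u32) >> 7;
        acc = (acc << 8) | (sign << 7) | (t as u32 & 127);
        acc_len += 8;
        acc = (acc << (w + 1)) | 1; // unary: w zeros then a terminating one
        acc_len += w + 1;
        while acc_len >= 8 {
            acc_len -= 8;
            *out.get_mut(v)? = (acc >> acc_len) as u8;
            v += 1;
        }
    }
    if acc_len > 0 {
        *out.get_mut(v)? = (acc << (8 - acc_len)) as u8;
    }
    Some(out)
}

/// Inverse of `compress` (Algorithm 18): returns `None` on malformed input.
fn decompress(bytes: &[u8], n: usize) -> Option<Vec<i16>> {
    let mut coeffs = Vec::with_capacity(n);
    let (mut acc, mut acc_len, mut idx) = (0u32, 0u32, 0usize);
    for _ in 0..n {
        acc = (acc << 8) | (*bytes.get(idx)? as u32);
        idx += 1;
        let b = acc >> acc_len;
        let s = b & 128;
        let mut m = b & 127;
        loop {
            if acc_len == 0 {
                acc = (acc << 8) | (*bytes.get(idx)? as u32);
                idx += 1;
                acc_len = 8;
            }
            acc_len -= 1;
            if (acc >> acc_len) & 1 != 0 {
                break;
            }
            m += 128;
            if m >= 2048 {
                return None;
            }
        }
        if s != 0 && m == 0 {
            return None; // -0 is forbidden
        }
        coeffs.push(if s != 0 { -(m as i16) } else { m as i16 });
    }
    Some(coeffs)
}

fn main() {
    let coeffs = [0i16, -1, 300, -2047, 127, 128];
    let bytes = compress(&coeffs, 64).unwrap();
    let round = decompress(&bytes, coeffs.len()).unwrap();
    assert_eq!(round, coeffs);
    println!("roundtrip ok: {round:?}");
}
```

The unary part is what makes the encoding variable-length per coefficient; the fixed `SIG_POLY_BYTE_LEN` works because Falcon signatures are rejected at signing time if they would not fit.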
// HELPER FUNCTIONS
// ================================================================================================
/// Takes the hash-to-point polynomial `c` of a message, the signature polynomial over
/// the message `s2` and a public key polynomial and returns `true` if the signature is a valid
/// signature for the given parameters, otherwise it returns `false`.
fn verify_helper(c: &Polynomial<FalconFelt>, s2: &SignaturePoly, h: &PubKeyPoly) -> bool {
let h_fft = h.fft();
let s2_fft = s2.fft();
let c_fft = c.fft();
// compute the signature polynomial s1 using s1 = c - s2 * h
let s1_fft = c_fft - s2_fft.hadamard_mul(&h_fft);
let s1 = s1_fft.ifft();
// compute the norm squared of (s1, s2)
let length_squared_s1 = s1.norm_squared();
let length_squared_s2 = s2.norm_squared();
let length_squared = length_squared_s1 + length_squared_s2;
length_squared < SIG_L2_BOUND
}
/// Checks whether a set of coefficients is a valid one for a signature polynomial.
fn are_coefficients_valid(x: &[i16]) -> bool {
if x.len() != N {
return false;
}
for &c in x {
if !(-2047..=2047).contains(&c) {
return false;
}
}
true
}
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
use rand::SeedableRng;
use rand_chacha::ChaCha20Rng;
use super::{super::SecretKey, *};
#[test]
fn test_serialization_round_trip() {
let seed = [0_u8; 32];
let mut rng = ChaCha20Rng::from_seed(seed);
let sk = SecretKey::with_rng(&mut rng);
let signature = sk.sign_with_rng(Word::default(), &mut rng);
let serialized = signature.to_bytes();
let deserialized = Signature::read_from_bytes(&serialized).unwrap();
assert_eq!(signature.sig_poly(), deserialized.sig_poly());
}
}

src/dsa/rpo_stark/mod.rs

@@ -0,0 +1,24 @@
mod signature;
pub use signature::{PublicKey, SecretKey, Signature};
mod stark;
pub use stark::{PublicInputs, RescueAir};
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
use super::SecretKey;
use crate::Word;
#[test]
fn test_signature() {
let sk = SecretKey::new(Word::default());
let message = Word::default();
let signature = sk.sign(message);
let pk = sk.public_key();
assert!(pk.verify(message, &signature))
}
}


@@ -0,0 +1,173 @@
use rand::{distributions::Uniform, prelude::Distribution, Rng};
use winter_air::{FieldExtension, ProofOptions};
use winter_math::{fields::f64::BaseElement, FieldElement};
use winter_prover::Proof;
use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
use crate::{
dsa::rpo_stark::stark::RpoSignatureScheme,
hash::{rpo::Rpo256, DIGEST_SIZE},
StarkField, Word, ZERO,
};
// CONSTANTS
// ================================================================================================
/// Specifies the parameters of the STARK underlying the signature scheme. These parameters provide
/// at least 102 bits of security under the conjectured security of the toy protocol in
/// the ethSTARK paper [1].
///
/// [1]: https://eprint.iacr.org/2021/582
pub const PROOF_OPTIONS: ProofOptions =
ProofOptions::new(30, 8, 12, FieldExtension::Quadratic, 4, 7, true);
// PUBLIC KEY
// ================================================================================================
/// A public key for verifying signatures.
///
/// The public key is a [Word] (i.e., 4 field elements) that is the hash of the secret key.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PublicKey(Word);
impl PublicKey {
/// Returns the [Word] defining the public key.
pub fn inner(&self) -> Word {
self.0
}
}
impl PublicKey {
/// Verifies the provided signature against provided message and this public key.
pub fn verify(&self, message: Word, signature: &Signature) -> bool {
signature.verify(message, *self)
}
}
impl Serializable for PublicKey {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
self.0.write_into(target);
}
}
impl Deserializable for PublicKey {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let pk = <Word>::read_from(source)?;
Ok(Self(pk))
}
}
// SECRET KEY
// ================================================================================================
/// A secret key for generating signatures.
///
/// The secret key is a [Word] (i.e., 4 field elements).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SecretKey(Word);
impl SecretKey {
/// Generates a secret key from a [Word].
pub fn new(word: Word) -> Self {
Self(word)
}
/// Generates a secret key from OS-provided randomness.
#[cfg(feature = "std")]
pub fn random() -> Self {
use rand::{rngs::StdRng, SeedableRng};
let mut rng = StdRng::from_entropy();
Self::with_rng(&mut rng)
}
/// Generates a secret_key using the provided random number generator `Rng`.
pub fn with_rng<R: Rng>(rng: &mut R) -> Self {
let mut sk = [ZERO; 4];
let uni_dist = Uniform::from(0..BaseElement::MODULUS);
for s in sk.iter_mut() {
let sampled_integer = uni_dist.sample(rng);
*s = BaseElement::new(sampled_integer);
}
Self(sk)
}
/// Computes the public key corresponding to this secret key.
pub fn public_key(&self) -> PublicKey {
let mut elements = [BaseElement::ZERO; 8];
elements[..DIGEST_SIZE].copy_from_slice(&self.0);
let pk = Rpo256::hash_elements(&elements);
PublicKey(pk.into())
}
/// Signs a message with this secret key.
pub fn sign(&self, message: Word) -> Signature {
let signature: RpoSignatureScheme<Rpo256> = RpoSignatureScheme::new(PROOF_OPTIONS);
let proof = signature.sign(self.0, message);
Signature { proof }
}
}
impl Serializable for SecretKey {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
self.0.write_into(target);
}
}
impl Deserializable for SecretKey {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let sk = <Word>::read_from(source)?;
Ok(Self(sk))
}
}
// SIGNATURE
// ================================================================================================
/// An RPO STARK-based signature over a message.
///
/// The signature is a STARK proof of knowledge of a pre-image given an image where the map is
/// the RPO permutation, the pre-image is the secret key and the image is the public key.
/// The current implementation follows the description in [1] but relies on the conjectured
/// security of the toy protocol in the ethSTARK paper [2]. With the parameter set given in
/// `PROOF_OPTIONS`, this yields a signature with $102$ bits of average-case existential
/// unforgeability security against $2^{113}$-query bound adversaries that can obtain up to
/// $2^{64}$ signatures under the same public key.
///
/// [1]: https://eprint.iacr.org/2024/1553
/// [2]: https://eprint.iacr.org/2021/582
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Signature {
proof: Proof,
}
impl Signature {
/// Returns the STARK proof constituting the signature.
pub fn inner(&self) -> Proof {
self.proof.clone()
}
/// Returns true if this signature is a valid signature for the specified message generated
/// against the secret key matching the specified public key.
pub fn verify(&self, message: Word, pk: PublicKey) -> bool {
let signature: RpoSignatureScheme<Rpo256> = RpoSignatureScheme::new(PROOF_OPTIONS);
let res = signature.verify(pk.inner(), message, self.proof.clone());
res.is_ok()
}
}
impl Serializable for Signature {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
self.proof.write_into(target);
}
}
impl Deserializable for Signature {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let proof = Proof::read_from(source)?;
Ok(Self { proof })
}
}


@@ -0,0 +1,198 @@
use alloc::vec::Vec;
use winter_math::{fields::f64::BaseElement, FieldElement, ToElements};
use winter_prover::{
Air, AirContext, Assertion, EvaluationFrame, ProofOptions, TraceInfo,
TransitionConstraintDegree,
};
use crate::{
hash::{ARK1, ARK2, MDS, STATE_WIDTH},
Word, ZERO,
};
// CONSTANTS
// ================================================================================================
pub const HASH_CYCLE_LEN: usize = 8;
// AIR
// ================================================================================================
pub struct RescueAir {
context: AirContext<BaseElement>,
pub_key: Word,
}
impl Air for RescueAir {
type BaseField = BaseElement;
type PublicInputs = PublicInputs;
type GkrProof = ();
type GkrVerifier = ();
// CONSTRUCTOR
// --------------------------------------------------------------------------------------------
fn new(trace_info: TraceInfo, pub_inputs: PublicInputs, options: ProofOptions) -> Self {
let degrees = vec![
// Apply RPO rounds.
TransitionConstraintDegree::new(7),
TransitionConstraintDegree::new(7),
TransitionConstraintDegree::new(7),
TransitionConstraintDegree::new(7),
TransitionConstraintDegree::new(7),
TransitionConstraintDegree::new(7),
TransitionConstraintDegree::new(7),
TransitionConstraintDegree::new(7),
TransitionConstraintDegree::new(7),
TransitionConstraintDegree::new(7),
TransitionConstraintDegree::new(7),
TransitionConstraintDegree::new(7),
];
assert_eq!(STATE_WIDTH, trace_info.width());
let context = AirContext::new(trace_info, degrees, 12, options);
let context = context.set_num_transition_exemptions(1);
RescueAir { context, pub_key: pub_inputs.pub_key }
}
fn context(&self) -> &AirContext<Self::BaseField> {
&self.context
}
fn evaluate_transition<E: FieldElement + From<Self::BaseField>>(
&self,
frame: &EvaluationFrame<E>,
periodic_values: &[E],
result: &mut [E],
) {
let current = frame.current();
let next = frame.next();
// expected state width is 12 field elements
debug_assert_eq!(STATE_WIDTH, current.len());
debug_assert_eq!(STATE_WIDTH, next.len());
enforce_rpo_round(frame, result, periodic_values);
}
fn get_assertions(&self) -> Vec<Assertion<Self::BaseField>> {
let initial_step = 0;
let last_step = self.trace_length() - 1;
vec![
// Assert that the capacity as well as the second half of the rate portion of the state
// are initialized to `ZERO`. The first half of the rate is unconstrained as it will
// contain the secret key.
Assertion::single(0, initial_step, Self::BaseField::ZERO),
Assertion::single(1, initial_step, Self::BaseField::ZERO),
Assertion::single(2, initial_step, Self::BaseField::ZERO),
Assertion::single(3, initial_step, Self::BaseField::ZERO),
Assertion::single(8, initial_step, Self::BaseField::ZERO),
Assertion::single(9, initial_step, Self::BaseField::ZERO),
Assertion::single(10, initial_step, Self::BaseField::ZERO),
Assertion::single(11, initial_step, Self::BaseField::ZERO),
// Assert that the public key is the correct one
Assertion::single(4, last_step, self.pub_key[0]),
Assertion::single(5, last_step, self.pub_key[1]),
Assertion::single(6, last_step, self.pub_key[2]),
Assertion::single(7, last_step, self.pub_key[3]),
]
}
fn get_periodic_column_values(&self) -> Vec<Vec<Self::BaseField>> {
get_round_constants()
}
}
pub struct PublicInputs {
pub(crate) pub_key: Word,
pub(crate) msg: Word,
}
impl PublicInputs {
pub fn new(pub_key: Word, msg: Word) -> Self {
Self { pub_key, msg }
}
}
impl ToElements<BaseElement> for PublicInputs {
fn to_elements(&self) -> Vec<BaseElement> {
let mut res = self.pub_key.to_vec();
res.extend_from_slice(self.msg.as_ref());
res
}
}
// HELPER EVALUATORS
// ------------------------------------------------------------------------------------------------
/// Enforces constraints for a single round of the Rescue Prime Optimized hash function.
pub fn enforce_rpo_round<E: FieldElement + From<BaseElement>>(
frame: &EvaluationFrame<E>,
result: &mut [E],
ark: &[E],
) {
// compute the state that should result from applying the first 5 operations of the RPO round to
// the current hash state.
let mut step1 = [E::ZERO; STATE_WIDTH];
step1.copy_from_slice(frame.current());
apply_mds(&mut step1);
// add constants
for i in 0..STATE_WIDTH {
step1[i] += ark[i];
}
apply_sbox(&mut step1);
apply_mds(&mut step1);
// add constants
for i in 0..STATE_WIDTH {
step1[i] += ark[STATE_WIDTH + i];
}
// compute the state that should result from applying the inverse of the last operation of the
// RPO round to the next step of the computation.
let mut step2 = [E::ZERO; STATE_WIDTH];
step2.copy_from_slice(frame.next());
apply_sbox(&mut step2);
// make sure that the results are equal.
for i in 0..STATE_WIDTH {
result[i] = step2[i] - step1[i]
}
}
#[inline(always)]
fn apply_sbox<E: FieldElement + From<BaseElement>>(state: &mut [E; STATE_WIDTH]) {
state.iter_mut().for_each(|v| {
let t2 = v.square();
let t4 = t2.square();
*v *= t2 * t4;
});
}
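The S-box above computes x^7 with two squarings and two multiplications (x · x^2 · x^4). A minimal standalone sketch of the same addition chain over the Goldilocks prime, using plain `u128` arithmetic instead of the crate's `Felt`/`FieldElement` types, and checked against a naive repeated-multiplication power:

```rust
// Sketch only: x^7 over the Goldilocks field as plain modular arithmetic,
// mirroring the t2/t4 structure of `apply_sbox` above.
const P: u128 = 0xffff_ffff_0000_0001; // 2^64 - 2^32 + 1

fn mul(a: u128, b: u128) -> u128 {
    a * b % P
}

fn sbox(x: u128) -> u128 {
    let t2 = mul(x, x);   // x^2
    let t4 = mul(t2, t2); // x^4
    mul(x, mul(t2, t4))   // x^7 = x * x^2 * x^4
}

fn pow_naive(x: u128, n: u32) -> u128 {
    let mut acc = 1u128;
    for _ in 0..n {
        acc = mul(acc, x);
    }
    acc
}

fn main() {
    for &x in &[0u128, 1, 5, P - 1, 123_456_789_123_456_789] {
        assert_eq!(sbox(x), pow_naive(x, 7));
    }
    println!("ok");
}
```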
#[inline(always)]
fn apply_mds<E: FieldElement + From<BaseElement>>(state: &mut [E; STATE_WIDTH]) {
let mut result = [E::ZERO; STATE_WIDTH];
result.iter_mut().zip(MDS).for_each(|(r, mds_row)| {
state.iter().zip(mds_row).for_each(|(&s, m)| {
*r += E::from(m) * s;
});
});
*state = result
}
/// Returns RPO round constants arranged in column-major form.
pub fn get_round_constants() -> Vec<Vec<BaseElement>> {
let mut constants = Vec::new();
for _ in 0..(STATE_WIDTH * 2) {
constants.push(vec![ZERO; HASH_CYCLE_LEN]);
}
#[allow(clippy::needless_range_loop)]
for i in 0..HASH_CYCLE_LEN - 1 {
for j in 0..STATE_WIDTH {
constants[j][i] = ARK1[i][j];
constants[j + STATE_WIDTH][i] = ARK2[i][j];
}
}
constants
}
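The transposition done by `get_round_constants` can be illustrated in isolation: round constants stored row-per-round (as in `ARK1`/`ARK2`) are re-arranged into one periodic column per state position, zero-padded to the hash cycle length. The sketch below uses a toy width and dummy constants rather than the real `ARK1`/`ARK2` tables:

```rust
// Sketch of the column-major re-arrangement in `get_round_constants`,
// with dummy constants. Row i of `ark` holds the constants for round i;
// the output holds one periodic column per state position, padded with
// a zero in the final cycle row (there are HASH_CYCLE_LEN - 1 rounds).
const STATE_WIDTH: usize = 3; // toy width (the real one is 12)
const HASH_CYCLE_LEN: usize = 8;
const NUM_ROUNDS: usize = HASH_CYCLE_LEN - 1;

fn to_periodic_columns(ark: &[[u64; STATE_WIDTH]; NUM_ROUNDS]) -> Vec<Vec<u64>> {
    let mut columns = vec![vec![0u64; HASH_CYCLE_LEN]; STATE_WIDTH];
    for i in 0..NUM_ROUNDS {
        for j in 0..STATE_WIDTH {
            columns[j][i] = ark[i][j];
        }
    }
    columns
}

fn main() {
    let ark = [
        [1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12],
        [13, 14, 15], [16, 17, 18], [19, 20, 21],
    ];
    let cols = to_periodic_columns(&ark);
    assert_eq!(cols.len(), STATE_WIDTH);
    // column 0 collects entry 0 of every round, then the zero pad
    assert_eq!(cols[0], vec![1, 4, 7, 10, 13, 16, 19, 0]);
    println!("ok");
}
```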


@@ -0,0 +1,98 @@
use alloc::vec::Vec;
use core::marker::PhantomData;
use prover::RpoSignatureProver;
use rand_chacha::ChaCha20Rng;
use winter_crypto::{ElementHasher, SaltedMerkleTree};
use winter_math::fields::f64::BaseElement;
use winter_prover::{Proof, ProofOptions, Prover};
use winter_utils::Serializable;
use winter_verifier::{verify, AcceptableOptions, VerifierError};
use crate::{
hash::{rpo::Rpo256, DIGEST_SIZE},
rand::RpoRandomCoin,
};
mod air;
pub use air::{PublicInputs, RescueAir};
mod prover;
/// Represents an abstract STARK-based signature scheme with knowledge of RPO pre-image as
/// the hard relation.
pub struct RpoSignatureScheme<H: ElementHasher> {
options: ProofOptions,
_h: PhantomData<H>,
}
impl<H: ElementHasher<BaseField = BaseElement> + Sync> RpoSignatureScheme<H> {
pub fn new(options: ProofOptions) -> Self {
RpoSignatureScheme { options, _h: PhantomData }
}
pub fn sign(&self, sk: [BaseElement; DIGEST_SIZE], msg: [BaseElement; DIGEST_SIZE]) -> Proof {
// create a prover
let prover = RpoSignatureProver::<H>::new(msg, self.options.clone());
// generate execution trace
let trace = prover.build_trace(sk);
// generate the initial seed for the PRNG used for zero-knowledge
let seed: [u8; 32] = generate_seed(sk, msg);
// generate the proof
prover.prove(trace, Some(seed)).expect("failed to generate the signature")
}
pub fn verify(
&self,
pub_key: [BaseElement; DIGEST_SIZE],
msg: [BaseElement; DIGEST_SIZE],
proof: Proof,
) -> Result<(), VerifierError> {
// we make sure that the parameters used in generating the proof match the expected ones
if *proof.options() != self.options {
return Err(VerifierError::UnacceptableProofOptions);
}
let pub_inputs = PublicInputs { pub_key, msg };
let acceptable_options = AcceptableOptions::OptionSet(vec![proof.options().clone()]);
verify::<RescueAir, Rpo256, RpoRandomCoin, SaltedMerkleTree<Rpo256, ChaCha20Rng>>(
proof,
pub_inputs,
&acceptable_options,
)
}
}
/// Deterministically generates a seed for the PRNG used for zero-knowledge.
///
/// This uses the argument described in [RFC 6979](https://datatracker.ietf.org/doc/html/rfc6979#section-3.5)
/// § 3.5 where the concatenation of the private key and the hashed message, i.e., sk || H(m), is
/// used in order to construct the initial seed of a PRNG.
///
/// Note that we also hash in a context string in order to domain-separate between different
/// instantiations of the signature scheme.
#[inline]
pub fn generate_seed(sk: [BaseElement; DIGEST_SIZE], msg: [BaseElement; DIGEST_SIZE]) -> [u8; 32] {
let context_bytes = "
Seed for PRNG used for Zero-knowledge in RPO-STARK signature scheme:
1. Version: Conjectured security
2. FRI queries: 30
3. Blowup factor: 8
4. Grinding bits: 12
5. Field extension degree: 2
6. FRI folding factor: 4
7. FRI remainder polynomial max degree: 7
"
.to_bytes();
let sk_bytes = sk.to_bytes();
let msg_bytes = msg.to_bytes();
let total_length = context_bytes.len() + sk_bytes.len() + msg_bytes.len();
let mut buffer = Vec::with_capacity(total_length);
buffer.extend_from_slice(&context_bytes);
buffer.extend_from_slice(&sk_bytes);
buffer.extend_from_slice(&msg_bytes);
blake3::hash(&buffer).into()
}
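The derivation pattern above (context string || sk || msg, hashed into a fixed seed) can be sketched with only the standard library. Note this is a toy stand-in: the real code serializes field elements and hashes with BLAKE3, while the placeholder below uses std's `DefaultHasher`, which is NOT cryptographic and only illustrates the buffer layout and determinism:

```rust
// Hedged sketch of the seed-derivation pattern: context || sk || msg
// concatenated into one buffer and hashed. DefaultHasher is a toy
// stand-in for BLAKE3 (not cryptographic, 64-bit output).
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

fn generate_seed(sk: &[u64; 4], msg: &[u64; 4]) -> u64 {
    let context = b"toy domain separator for PRNG seed";
    let mut buffer: Vec<u8> = Vec::new();
    buffer.extend_from_slice(context);
    for limb in sk {
        buffer.extend_from_slice(&limb.to_le_bytes());
    }
    for limb in msg {
        buffer.extend_from_slice(&limb.to_le_bytes());
    }
    let mut h = DefaultHasher::new();
    buffer.hash(&mut h);
    h.finish()
}

fn main() {
    let (sk, msg) = ([1, 2, 3, 4], [5, 6, 7, 8]);
    // deterministic: same inputs always give the same seed
    assert_eq!(generate_seed(&sk, &msg), generate_seed(&sk, &msg));
    // a different message (almost surely) gives a different seed
    assert_ne!(generate_seed(&sk, &msg), generate_seed(&sk, &[9, 6, 7, 8]));
    println!("ok");
}
```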


@@ -0,0 +1,148 @@
use core::marker::PhantomData;
use rand_chacha::ChaCha20Rng;
use winter_air::{
AuxRandElements, ConstraintCompositionCoefficients, PartitionOptions, ZkParameters,
};
use winter_crypto::{ElementHasher, SaltedMerkleTree};
use winter_math::{fields::f64::BaseElement, FieldElement};
use winter_prover::{
matrix::ColMatrix, CompositionPoly, CompositionPolyTrace, DefaultConstraintCommitment,
DefaultConstraintEvaluator, DefaultTraceLde, ProofOptions, Prover, StarkDomain, Trace,
TraceInfo, TracePolyTable, TraceTable,
};
use super::air::{PublicInputs, RescueAir, HASH_CYCLE_LEN};
use crate::{
hash::{rpo::Rpo256, STATE_WIDTH},
rand::RpoRandomCoin,
Word, ZERO,
};
// PROVER
// ================================================================================================
/// A prover for the RPO STARK-based signature scheme.
///
/// The signature is based on the one-wayness of the RPO hash function, but the prover is
/// generic over the hash function used to instantiate the random oracle for the BCS transform.
pub(crate) struct RpoSignatureProver<H: ElementHasher + Sync> {
message: Word,
options: ProofOptions,
_hasher: PhantomData<H>,
}
impl<H: ElementHasher + Sync> RpoSignatureProver<H> {
pub(crate) fn new(message: Word, options: ProofOptions) -> Self {
Self { message, options, _hasher: PhantomData }
}
pub(crate) fn build_trace(&self, sk: Word) -> TraceTable<BaseElement> {
let mut trace = TraceTable::new(STATE_WIDTH, HASH_CYCLE_LEN);
trace.fill(
|state| {
// initialize first half of the rate portion of the state with the secret key
state[0] = ZERO;
state[1] = ZERO;
state[2] = ZERO;
state[3] = ZERO;
state[4] = sk[0];
state[5] = sk[1];
state[6] = sk[2];
state[7] = sk[3];
state[8] = ZERO;
state[9] = ZERO;
state[10] = ZERO;
state[11] = ZERO;
},
|step, state| {
Rpo256::apply_round(
state.try_into().expect("should not fail given the size of the array"),
step,
);
},
);
trace
}
}
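The `fill(init, update)` pattern used by `build_trace` above takes one closure to seed row 0 and one to map row `step` to row `step + 1`. A self-contained toy version, with `TraceTable` replaced by a `Vec` of rows and the RPO round replaced by a trivial placeholder update (these stand-ins are illustrative, not the crate's API):

```rust
// Toy version of the fill(init, update) trace-building pattern.
const WIDTH: usize = 4;
const LENGTH: usize = 8;

fn fill<I, U>(init: I, update: U) -> Vec<[u64; WIDTH]>
where
    I: FnOnce(&mut [u64; WIDTH]),
    U: Fn(usize, &mut [u64; WIDTH]),
{
    let mut rows = Vec::with_capacity(LENGTH);
    let mut state = [0u64; WIDTH];
    init(&mut state); // row 0
    rows.push(state);
    for step in 0..LENGTH - 1 {
        update(step, &mut state); // row step + 1
        rows.push(state);
    }
    rows
}

fn main() {
    let sk = [10u64, 11, 12, 13];
    let trace = fill(
        |state| *state = sk, // init: seed with the "secret key"
        |step, state| {
            // stand-in for Rpo256::apply_round(state, step)
            for v in state.iter_mut() {
                *v += step as u64 + 1;
            }
        },
    );
    assert_eq!(trace.len(), LENGTH);
    assert_eq!(trace[0], sk);
    assert_eq!(trace[LENGTH - 1][0], 10 + (1..=7).sum::<u64>());
    println!("ok");
}
```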
impl<H: ElementHasher> Prover for RpoSignatureProver<H>
where
H: ElementHasher<BaseField = BaseElement> + Sync,
{
type BaseField = BaseElement;
type Air = RescueAir;
type Trace = TraceTable<BaseElement>;
type HashFn = Rpo256;
type VC = SaltedMerkleTree<Self::HashFn, Self::ZkPrng>;
type RandomCoin = RpoRandomCoin;
type TraceLde<E: FieldElement<BaseField = Self::BaseField>> =
DefaultTraceLde<E, Self::HashFn, Self::VC>;
type ConstraintCommitment<E: FieldElement<BaseField = Self::BaseField>> =
DefaultConstraintCommitment<E, Self::HashFn, Self::ZkPrng, Self::VC>;
type ConstraintEvaluator<'a, E: FieldElement<BaseField = Self::BaseField>> =
DefaultConstraintEvaluator<'a, Self::Air, E>;
type ZkPrng = ChaCha20Rng;
fn get_pub_inputs(&self, trace: &Self::Trace) -> PublicInputs {
let last_step = trace.length() - 1;
// Note that the message is not part of the execution trace but is part of the public
// inputs. This is explained in the reference description of the DSA; intuitively, it
// ensures that the message is part of the Fiat-Shamir transcript and hence binds the
// proof/signature to the message.
PublicInputs {
pub_key: [
trace.get(4, last_step),
trace.get(5, last_step),
trace.get(6, last_step),
trace.get(7, last_step),
],
msg: self.message,
}
}
fn options(&self) -> &ProofOptions {
&self.options
}
fn new_trace_lde<E: FieldElement<BaseField = Self::BaseField>>(
&self,
trace_info: &TraceInfo,
main_trace: &ColMatrix<Self::BaseField>,
domain: &StarkDomain<Self::BaseField>,
partition_option: PartitionOptions,
zk_parameters: Option<ZkParameters>,
prng: &mut Option<Self::ZkPrng>,
) -> (Self::TraceLde<E>, TracePolyTable<E>) {
DefaultTraceLde::new(trace_info, main_trace, domain, partition_option, zk_parameters, prng)
}
fn new_evaluator<'a, E: FieldElement<BaseField = Self::BaseField>>(
&self,
air: &'a Self::Air,
aux_rand_elements: Option<AuxRandElements<E>>,
composition_coefficients: ConstraintCompositionCoefficients<E>,
) -> Self::ConstraintEvaluator<'a, E> {
DefaultConstraintEvaluator::new(air, aux_rand_elements, composition_coefficients)
}
fn build_constraint_commitment<E: FieldElement<BaseField = Self::BaseField>>(
&self,
composition_poly_trace: CompositionPolyTrace<E>,
num_constraint_composition_columns: usize,
domain: &StarkDomain<Self::BaseField>,
partition_options: PartitionOptions,
zk_parameters: Option<ZkParameters>,
prng: &mut Option<Self::ZkPrng>,
) -> (Self::ConstraintCommitment<E>, CompositionPoly<E>) {
DefaultConstraintCommitment::new(
composition_poly_trace,
num_constraint_composition_columns,
domain,
partition_options,
zk_parameters,
prng,
)
}
}


@@ -1,12 +1,14 @@
use super::{Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField};
use crate::utils::{
bytes_to_hex_string, hex_to_bytes, string::String, ByteReader, ByteWriter, Deserializable,
DeserializationError, HexParseError, Serializable,
};
use alloc::{string::String, vec::Vec};
use core::{
mem::{size_of, transmute, transmute_copy},
ops::Deref,
slice::from_raw_parts,
slice::{self, from_raw_parts},
};
use super::{Digest, ElementHasher, Felt, FieldElement, Hasher};
use crate::utils::{
bytes_to_hex_string, hex_to_bytes, ByteReader, ByteWriter, Deserializable,
DeserializationError, HexParseError, Serializable,
};
#[cfg(test)]
@@ -31,6 +33,14 @@ const DIGEST20_BYTES: usize = 20;
#[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))]
pub struct Blake3Digest<const N: usize>([u8; N]);
impl<const N: usize> Blake3Digest<N> {
pub fn digests_as_bytes(digests: &[Blake3Digest<N>]) -> &[u8] {
let p = digests.as_ptr();
let len = digests.len() * N;
unsafe { slice::from_raw_parts(p as *const u8, len) }
}
}
impl<const N: usize> Default for Blake3Digest<N> {
fn default() -> Self {
Self([0; N])
@@ -112,6 +122,10 @@ impl Hasher for Blake3_256 {
Self::hash(prepare_merge(values))
}
fn merge_many(values: &[Self::Digest]) -> Self::Digest {
Blake3Digest(blake3::hash(Blake3Digest::digests_as_bytes(values)).into())
}
fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
let mut hasher = blake3::Hasher::new();
hasher.update(&seed.0);
@@ -172,6 +186,11 @@ impl Hasher for Blake3_192 {
Blake3Digest(*shrink_bytes(&blake3::hash(bytes).into()))
}
fn merge_many(values: &[Self::Digest]) -> Self::Digest {
let bytes: Vec<u8> = values.iter().flat_map(|v| v.as_bytes()).collect();
Blake3Digest(*shrink_bytes(&blake3::hash(&bytes).into()))
}
fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
Self::hash(prepare_merge(values))
}
@@ -240,6 +259,11 @@ impl Hasher for Blake3_160 {
Self::hash(prepare_merge(values))
}
fn merge_many(values: &[Self::Digest]) -> Self::Digest {
let bytes: Vec<u8> = values.iter().flat_map(|v| v.as_bytes()).collect();
Blake3Digest(*shrink_bytes(&blake3::hash(&bytes).into()))
}
fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
let mut hasher = blake3::Hasher::new();
hasher.update(&seed.0);


@@ -1,8 +1,10 @@
use super::*;
use crate::utils::collections::Vec;
use alloc::vec::Vec;
use proptest::prelude::*;
use rand_utils::rand_vector;
use super::*;
#[test]
fn blake3_hash_elements() {
// test multiple of 8


@@ -1,9 +1,18 @@
//! Cryptographic hash functions used by the Miden VM and the Miden rollup.
use super::{Felt, FieldElement, StarkField, ONE, ZERO};
use super::{CubeExtension, Felt, FieldElement, StarkField, ZERO};
pub mod blake;
pub mod rpo;
mod rescue;
pub(crate) use rescue::{ARK1, ARK2, DIGEST_SIZE, MDS, STATE_WIDTH};
pub mod rpo {
pub use super::rescue::{Rpo256, RpoDigest, RpoDigestError};
}
pub mod rpx {
pub use super::rescue::{Rpx256, RpxDigest, RpxDigestError};
}
// RE-EXPORTS
// ================================================================================================

src/hash/rescue/arch/mod.rs Normal file

@@ -0,0 +1,101 @@
#[cfg(target_feature = "sve")]
pub mod optimized {
use crate::{hash::rescue::STATE_WIDTH, Felt};
mod ffi {
#[link(name = "rpo_sve", kind = "static")]
extern "C" {
pub fn add_constants_and_apply_sbox(
state: *mut std::ffi::c_ulong,
constants: *const std::ffi::c_ulong,
) -> bool;
pub fn add_constants_and_apply_inv_sbox(
state: *mut std::ffi::c_ulong,
constants: *const std::ffi::c_ulong,
) -> bool;
}
}
#[inline(always)]
pub fn add_constants_and_apply_sbox(
state: &mut [Felt; STATE_WIDTH],
ark: &[Felt; STATE_WIDTH],
) -> bool {
unsafe {
ffi::add_constants_and_apply_sbox(
state.as_mut_ptr() as *mut u64,
ark.as_ptr() as *const u64,
)
}
}
#[inline(always)]
pub fn add_constants_and_apply_inv_sbox(
state: &mut [Felt; STATE_WIDTH],
ark: &[Felt; STATE_WIDTH],
) -> bool {
unsafe {
ffi::add_constants_and_apply_inv_sbox(
state.as_mut_ptr() as *mut u64,
ark.as_ptr() as *const u64,
)
}
}
}
#[cfg(target_feature = "avx2")]
mod x86_64_avx2;
#[cfg(target_feature = "avx2")]
pub mod optimized {
use super::x86_64_avx2::{apply_inv_sbox, apply_sbox};
use crate::{
hash::rescue::{add_constants, STATE_WIDTH},
Felt,
};
#[inline(always)]
pub fn add_constants_and_apply_sbox(
state: &mut [Felt; STATE_WIDTH],
ark: &[Felt; STATE_WIDTH],
) -> bool {
add_constants(state, ark);
unsafe {
apply_sbox(std::mem::transmute(state));
}
true
}
#[inline(always)]
pub fn add_constants_and_apply_inv_sbox(
state: &mut [Felt; STATE_WIDTH],
ark: &[Felt; STATE_WIDTH],
) -> bool {
add_constants(state, ark);
unsafe {
apply_inv_sbox(std::mem::transmute(state));
}
true
}
}
#[cfg(not(any(target_feature = "avx2", target_feature = "sve")))]
pub mod optimized {
use crate::{hash::rescue::STATE_WIDTH, Felt};
#[inline(always)]
pub fn add_constants_and_apply_sbox(
_state: &mut [Felt; STATE_WIDTH],
_ark: &[Felt; STATE_WIDTH],
) -> bool {
false
}
#[inline(always)]
pub fn add_constants_and_apply_inv_sbox(
_state: &mut [Felt; STATE_WIDTH],
_ark: &[Felt; STATE_WIDTH],
) -> bool {
false
}
}
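All three `optimized` modules above share one calling convention: return `true` when a SIMD path handled the work, `false` so the caller can fall back to scalar code. A minimal sketch of that dispatch pattern, with a toy payload (doubling each element) standing in for add-constants-plus-S-box:

```rust
// Sketch of the bool-returning dispatch pattern used by the `optimized`
// modules. The toy payload and names are illustrative only.
const STATE_WIDTH: usize = 4; // toy width

mod optimized {
    // portable build: report that nothing was done
    pub fn try_fast_path(_state: &mut [u64; super::STATE_WIDTH]) -> bool {
        false
    }
}

fn scalar_path(state: &mut [u64; STATE_WIDTH]) {
    for v in state.iter_mut() {
        *v *= 2;
    }
}

fn apply(state: &mut [u64; STATE_WIDTH]) {
    // exactly like the real call sites: try the fast path, fall back if
    // it reports that it did nothing
    if !optimized::try_fast_path(state) {
        scalar_path(state);
    }
}

fn main() {
    let mut state = [1u64, 2, 3, 4];
    apply(&mut state);
    assert_eq!(state, [2, 4, 6, 8]);
    println!("ok");
}
```

This keeps the scalar and vectorized builds behind one stable signature, so callers never need their own `cfg` logic.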


@@ -0,0 +1,328 @@
use core::arch::x86_64::*;
// The following AVX2 implementation has been copied from plonky2:
// https://github.com/0xPolygonZero/plonky2/blob/main/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2_bmi2.rs
// Preliminary notes:
// 1. AVX does not support addition with carry but 128-bit (2-word) addition can be easily emulated.
// The method relies on the fact that a + b overflowed iff (a + b) < a:
// 1. res_lo = a_lo + b_lo
// 2. carry_mask = res_lo < a_lo
// 3. res_hi = a_hi + b_hi - carry_mask
//
// Notice that carry_mask is subtracted, not added. This is because AVX comparison instructions
// return -1 (all bits 1) for true and 0 for false.
//
// 2. AVX does not have unsigned 64-bit comparisons. Those can be emulated with signed comparisons
// by recognizing that a <u b iff a + (1 << 63) <s b + (1 << 63), where the addition wraps around
// and the comparisons are unsigned and signed respectively. The shift function adds/subtracts 1
// << 63 to enable this trick. Addition with carry example:
// 1. a_lo_s = shift(a_lo)
// 2. res_lo_s = a_lo_s + b_lo
// 3. carry_mask = res_lo_s <s a_lo_s
// 4. res_lo = shift(res_lo_s)
// 5. res_hi = a_hi + b_hi - carry_mask
//
// The suffix _s denotes a value that has been shifted by 1 << 63. The result of addition
// is shifted if exactly one of the operands is shifted, as is the case on
// line 2. Line 3. performs a signed comparison res_lo_s <s a_lo_s on shifted values to
// emulate unsigned comparison res_lo <u a_lo on unshifted values. Finally, line 4. reverses the
// shift so the result can be returned.
//
// When performing a chain of calculations, we can often save instructions by letting
// the shift propagate through and only undoing it when necessary.
// For example, to compute the addition of three two-word (128-bit) numbers we can do:
// 1. a_lo_s = shift(a_lo)
// 2. tmp_lo_s = a_lo_s + b_lo
// 3. tmp_carry_mask = tmp_lo_s <s a_lo_s
// 4. tmp_hi = a_hi + b_hi - tmp_carry_mask
// 5. res_lo_s = tmp_lo_s + c_lo
// 6. res_carry_mask = res_lo_s <s tmp_lo_s
// 7. res_lo = shift(res_lo_s)
// 8. res_hi = tmp_hi + c_hi - res_carry_mask
//
// Notice that the above 3-value addition still only requires two calls to shift, just like our
// 2-value addition.
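The shift trick described above can be checked in scalar form: unsigned comparison `a <u b` equals a signed comparison once the top bit of both operands is flipped, which is why the vectors are kept in "shifted" (`_s`) form between operations. A minimal sketch:

```rust
// a <u b  iff  (a ^ 2^63) <s (b ^ 2^63): flipping the top bit maps the
// unsigned order onto the signed order, which is all AVX2 provides for
// 64-bit lanes.
fn unsigned_lt_via_signed(a: u64, b: u64) -> bool {
    const SHIFT: u64 = 1 << 63;
    ((a ^ SHIFT) as i64) < ((b ^ SHIFT) as i64)
}

fn main() {
    let samples = [0u64, 1, (1 << 63) - 1, 1 << 63, u64::MAX - 1, u64::MAX];
    for &a in &samples {
        for &b in &samples {
            assert_eq!(unsigned_lt_via_signed(a, b), a < b);
        }
    }
    println!("ok");
}
```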
#[inline(always)]
pub fn branch_hint() {
// NOTE: These are the currently supported assembly architectures. See the
// [nightly reference](https://doc.rust-lang.org/nightly/reference/inline-assembly.html) for
// the most up-to-date list.
#[cfg(any(
target_arch = "aarch64",
target_arch = "arm",
target_arch = "riscv32",
target_arch = "riscv64",
target_arch = "x86",
target_arch = "x86_64",
))]
unsafe {
core::arch::asm!("", options(nomem, nostack, preserves_flags));
}
}
macro_rules! map3 {
($f:ident:: < $l:literal > , $v:ident) => {
($f::<$l>($v.0), $f::<$l>($v.1), $f::<$l>($v.2))
};
($f:ident:: < $l:literal > , $v1:ident, $v2:ident) => {
($f::<$l>($v1.0, $v2.0), $f::<$l>($v1.1, $v2.1), $f::<$l>($v1.2, $v2.2))
};
($f:ident, $v:ident) => {
($f($v.0), $f($v.1), $f($v.2))
};
($f:ident, $v0:ident, $v1:ident) => {
($f($v0.0, $v1.0), $f($v0.1, $v1.1), $f($v0.2, $v1.2))
};
($f:ident,rep $v0:ident, $v1:ident) => {
($f($v0, $v1.0), $f($v0, $v1.1), $f($v0, $v1.2))
};
($f:ident, $v0:ident,rep $v1:ident) => {
($f($v0.0, $v1), $f($v0.1, $v1), $f($v0.2, $v1))
};
}
#[inline(always)]
unsafe fn square3(
x: (__m256i, __m256i, __m256i),
) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) {
let x_hi = {
// Move high bits to low position. The high bits of x_hi are ignored. Swizzle is faster than
// bitshift. This instruction only has a floating-point flavor, so we cast to/from float.
// This is safe and free.
let x_ps = map3!(_mm256_castsi256_ps, x);
let x_hi_ps = map3!(_mm256_movehdup_ps, x_ps);
map3!(_mm256_castps_si256, x_hi_ps)
};
// All pairwise multiplications.
let mul_ll = map3!(_mm256_mul_epu32, x, x);
let mul_lh = map3!(_mm256_mul_epu32, x, x_hi);
let mul_hh = map3!(_mm256_mul_epu32, x_hi, x_hi);
// Bignum addition, but mul_lh is shifted by 33 bits (not 32).
let mul_ll_hi = map3!(_mm256_srli_epi64::<33>, mul_ll);
let t0 = map3!(_mm256_add_epi64, mul_lh, mul_ll_hi);
let t0_hi = map3!(_mm256_srli_epi64::<31>, t0);
let res_hi = map3!(_mm256_add_epi64, mul_hh, t0_hi);
// Form low result by adding the mul_ll and the low 31 bits of mul_lh (shifted to the high
// position).
let mul_lh_lo = map3!(_mm256_slli_epi64::<33>, mul_lh);
let res_lo = map3!(_mm256_add_epi64, mul_ll, mul_lh_lo);
(res_lo, res_hi)
}
#[inline(always)]
unsafe fn mul3(
x: (__m256i, __m256i, __m256i),
y: (__m256i, __m256i, __m256i),
) -> ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)) {
let epsilon = _mm256_set1_epi64x(0xffffffff);
let x_hi = {
// Move high bits to low position. The high bits of x_hi are ignored. Swizzle is faster than
// bitshift. This instruction only has a floating-point flavor, so we cast to/from float.
// This is safe and free.
let x_ps = map3!(_mm256_castsi256_ps, x);
let x_hi_ps = map3!(_mm256_movehdup_ps, x_ps);
map3!(_mm256_castps_si256, x_hi_ps)
};
let y_hi = {
let y_ps = map3!(_mm256_castsi256_ps, y);
let y_hi_ps = map3!(_mm256_movehdup_ps, y_ps);
map3!(_mm256_castps_si256, y_hi_ps)
};
// All four pairwise multiplications
let mul_ll = map3!(_mm256_mul_epu32, x, y);
let mul_lh = map3!(_mm256_mul_epu32, x, y_hi);
let mul_hl = map3!(_mm256_mul_epu32, x_hi, y);
let mul_hh = map3!(_mm256_mul_epu32, x_hi, y_hi);
// Bignum addition
// Extract high 32 bits of mul_ll and add to mul_hl. This cannot overflow.
let mul_ll_hi = map3!(_mm256_srli_epi64::<32>, mul_ll);
let t0 = map3!(_mm256_add_epi64, mul_hl, mul_ll_hi);
// Extract low 32 bits of t0 and add to mul_lh. Again, this cannot overflow.
// Also, extract high 32 bits of t0 and add to mul_hh.
let t0_lo = map3!(_mm256_and_si256, t0, rep epsilon);
let t0_hi = map3!(_mm256_srli_epi64::<32>, t0);
let t1 = map3!(_mm256_add_epi64, mul_lh, t0_lo);
let t2 = map3!(_mm256_add_epi64, mul_hh, t0_hi);
// Lastly, extract the high 32 bits of t1 and add to t2.
let t1_hi = map3!(_mm256_srli_epi64::<32>, t1);
let res_hi = map3!(_mm256_add_epi64, t2, t1_hi);
// Form res_lo by combining the low half of mul_ll with the low half of t1 (shifted into high
// position).
let t1_lo = {
let t1_ps = map3!(_mm256_castsi256_ps, t1);
let t1_lo_ps = map3!(_mm256_moveldup_ps, t1_ps);
map3!(_mm256_castps_si256, t1_lo_ps)
};
let res_lo = map3!(_mm256_blend_epi32::<0xaa>, mul_ll, t1_lo);
(res_lo, res_hi)
}
/// Addition, where the second operand is `0 <= y < 0xffffffff00000001`.
#[inline(always)]
unsafe fn add_small(
x_s: (__m256i, __m256i, __m256i),
y: (__m256i, __m256i, __m256i),
) -> (__m256i, __m256i, __m256i) {
let res_wrapped_s = map3!(_mm256_add_epi64, x_s, y);
let mask = map3!(_mm256_cmpgt_epi32, x_s, res_wrapped_s);
let wrapback_amt = map3!(_mm256_srli_epi64::<32>, mask); // EPSILON if overflowed else 0.
let res_s = map3!(_mm256_add_epi64, res_wrapped_s, wrapback_amt);
res_s
}
#[inline(always)]
unsafe fn maybe_adj_sub(res_wrapped_s: __m256i, mask: __m256i) -> __m256i {
// The subtraction is very unlikely to overflow so we're best off branching.
// The even u32s in `mask` are meaningless, so we want to ignore them. `_mm256_testz_pd`
// branches depending on the sign bit of double-precision (64-bit) floats. Bit cast `mask` to
// floating-point (this is free).
let mask_pd = _mm256_castsi256_pd(mask);
// `_mm256_testz_pd(mask_pd, mask_pd) == 1` iff all sign bits are 0, meaning that underflow
// did not occur for any of the vector elements.
if _mm256_testz_pd(mask_pd, mask_pd) == 1 {
res_wrapped_s
} else {
branch_hint();
// Highly unlikely: underflow did occur. Find adjustment per element and apply it.
let adj_amount = _mm256_srli_epi64::<32>(mask); // EPSILON if underflow.
_mm256_sub_epi64(res_wrapped_s, adj_amount)
}
}
/// Subtraction, where the second operand is much smaller than `0xffffffff00000001`.
#[inline(always)]
unsafe fn sub_tiny(
x_s: (__m256i, __m256i, __m256i),
y: (__m256i, __m256i, __m256i),
) -> (__m256i, __m256i, __m256i) {
let res_wrapped_s = map3!(_mm256_sub_epi64, x_s, y);
let mask = map3!(_mm256_cmpgt_epi32, res_wrapped_s, x_s);
let res_s = map3!(maybe_adj_sub, res_wrapped_s, mask);
res_s
}
#[inline(always)]
unsafe fn reduce3(
(lo0, hi0): ((__m256i, __m256i, __m256i), (__m256i, __m256i, __m256i)),
) -> (__m256i, __m256i, __m256i) {
let sign_bit = _mm256_set1_epi64x(i64::MIN);
let epsilon = _mm256_set1_epi64x(0xffffffff);
let lo0_s = map3!(_mm256_xor_si256, lo0, rep sign_bit);
let hi_hi0 = map3!(_mm256_srli_epi64::<32>, hi0);
let lo1_s = sub_tiny(lo0_s, hi_hi0);
let t1 = map3!(_mm256_mul_epu32, hi0, rep epsilon);
let lo2_s = add_small(lo1_s, t1);
let lo2 = map3!(_mm256_xor_si256, lo2_s, rep sign_bit);
lo2
}
#[inline(always)]
unsafe fn mul_reduce(
a: (__m256i, __m256i, __m256i),
b: (__m256i, __m256i, __m256i),
) -> (__m256i, __m256i, __m256i) {
reduce3(mul3(a, b))
}
#[inline(always)]
unsafe fn square_reduce(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) {
reduce3(square3(state))
}
#[inline(always)]
unsafe fn exp_acc(
high: (__m256i, __m256i, __m256i),
low: (__m256i, __m256i, __m256i),
exp: usize,
) -> (__m256i, __m256i, __m256i) {
let mut result = high;
for _ in 0..exp {
result = square_reduce(result);
}
mul_reduce(result, low)
}
#[inline(always)]
unsafe fn do_apply_sbox(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) {
let state2 = square_reduce(state);
let state4_unreduced = square3(state2);
let state3_unreduced = mul3(state2, state);
let state4 = reduce3(state4_unreduced);
let state3 = reduce3(state3_unreduced);
let state7_unreduced = mul3(state3, state4);
let state7 = reduce3(state7_unreduced);
state7
}
#[inline(always)]
unsafe fn do_apply_inv_sbox(state: (__m256i, __m256i, __m256i)) -> (__m256i, __m256i, __m256i) {
// compute base^10540996611094048183 using 72 multiplications per array element
// 10540996611094048183 = b1001001001001001001001001001000110110110110110110110110110110111
// compute base^10
let t1 = square_reduce(state);
// compute base^100
let t2 = square_reduce(t1);
// compute base^100100
let t3 = exp_acc(t2, t2, 3);
// compute base^100100100100
let t4 = exp_acc(t3, t3, 6);
// compute base^100100100100100100100100
let t5 = exp_acc(t4, t4, 12);
// compute base^100100100100100100100100100100
let t6 = exp_acc(t5, t3, 6);
// compute base^1001001001001001001001001001000100100100100100100100100100100
let t7 = exp_acc(t6, t6, 31);
// compute base^1001001001001001001001001001000110110110110110110110110110110111
let a = square_reduce(square_reduce(mul_reduce(square_reduce(t7), t6)));
let b = mul_reduce(t1, mul_reduce(t2, state));
mul_reduce(a, b)
}
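The exponent in the comment above can be sanity-checked independently: over the Goldilocks field, x → x^7 is inverted by x → x^10540996611094048183 because 7 · 10540996611094048183 ≡ 1 (mod p − 1). A standalone check using plain modular exponentiation:

```rust
// Verify that INV_ALPHA inverts the x^7 S-box over the Goldilocks field.
const P: u64 = 0xffff_ffff_0000_0001;
const INV_ALPHA: u64 = 10_540_996_611_094_048_183;

fn pow_mod(mut base: u64, mut exp: u64) -> u64 {
    let mut acc: u64 = 1;
    base %= P;
    while exp > 0 {
        if exp & 1 == 1 {
            acc = ((acc as u128 * base as u128) % P as u128) as u64;
        }
        base = ((base as u128 * base as u128) % P as u128) as u64;
        exp >>= 1;
    }
    acc
}

fn main() {
    // exponents compose modulo the multiplicative group order p - 1
    assert_eq!((7u128 * INV_ALPHA as u128) % (P as u128 - 1), 1);
    // round-trip: inv_sbox(sbox(x)) == x for a few sample points
    for &x in &[1u64, 2, 3, 123_456_789, P - 2] {
        assert_eq!(pow_mod(pow_mod(x, 7), INV_ALPHA), x);
    }
    println!("ok");
}
```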
#[inline(always)]
unsafe fn avx2_load(state: &[u64; 12]) -> (__m256i, __m256i, __m256i) {
(
_mm256_loadu_si256((&state[0..4]).as_ptr().cast::<__m256i>()),
_mm256_loadu_si256((&state[4..8]).as_ptr().cast::<__m256i>()),
_mm256_loadu_si256((&state[8..12]).as_ptr().cast::<__m256i>()),
)
}
#[inline(always)]
unsafe fn avx2_store(buf: &mut [u64; 12], state: (__m256i, __m256i, __m256i)) {
_mm256_storeu_si256((&mut buf[0..4]).as_mut_ptr().cast::<__m256i>(), state.0);
_mm256_storeu_si256((&mut buf[4..8]).as_mut_ptr().cast::<__m256i>(), state.1);
_mm256_storeu_si256((&mut buf[8..12]).as_mut_ptr().cast::<__m256i>(), state.2);
}
#[inline(always)]
pub unsafe fn apply_sbox(buffer: &mut [u64; 12]) {
let mut state = avx2_load(buffer);
state = do_apply_sbox(state);
avx2_store(buffer, state);
}
#[inline(always)]
pub unsafe fn apply_inv_sbox(buffer: &mut [u64; 12]) {
let mut state = avx2_load(buffer);
state = do_apply_inv_sbox(state);
avx2_store(buffer, state);
}


@@ -1,32 +1,35 @@
// FFT-BASED MDS MULTIPLICATION HELPER FUNCTIONS
// ================================================================================================
/// This module contains helper functions as well as constants used to perform the vector-matrix
/// multiplication step of the Rescue prime permutation. The special form of our MDS matrix
/// i.e. being circular, allows us to reduce the vector-matrix multiplication to a Hadamard product
/// of two vectors in "frequency domain". This follows from the simple fact that every circulant
/// matrix has the columns of the discrete Fourier transform matrix as orthogonal eigenvectors.
/// The implementation also avoids the use of 3-point FFTs, and 3-point iFFTs, and substitutes that
/// with explicit expressions. It also avoids, due to the form of our matrix in the frequency domain,
/// divisions by 2 and repeated modular reductions. This is because of our explicit choice of
/// an MDS matrix that has small powers of 2 entries in frequency domain.
/// The following implementation has benefited greatly from the discussions and insights of
/// Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero.
//! This module contains helper functions as well as constants used to perform the vector-matrix
//! multiplication step of the Rescue prime permutation. The special form of our MDS matrix
//! i.e. being circulant, allows us to reduce the vector-matrix multiplication to a Hadamard product
//! of two vectors in "frequency domain". This follows from the simple fact that every circulant
//! matrix has the columns of the discrete Fourier transform matrix as orthogonal eigenvectors.
//! The implementation also avoids the use of 3-point FFTs, and 3-point iFFTs, and substitutes that
//! with explicit expressions. It also avoids, due to the form of our matrix in the frequency
//! domain, divisions by 2 and repeated modular reductions. This is because of our explicit choice
//! of an MDS matrix that has small powers of 2 entries in frequency domain.
//! The following implementation has benefited greatly from the discussions and insights of
//! Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero and is based on Nabaglo's Plonky2
//! implementation.
// Rescue MDS matrix in frequency domain.
//
// More precisely, this is the output of the three 4-point (real) FFTs of the first column of
// the MDS matrix i.e. just before the multiplication with the appropriate twiddle factors
// and application of the final four 3-point FFT in order to get the full 12-point FFT.
// The entries have been scaled appropriately in order to avoid divisions by 2 in iFFT2 and iFFT4.
// The code to generate the matrix in frequency domain is based on an adaptation of a code, to generate
// MDS matrices efficiently in original domain, that was developed by the Polygon Zero team.
// The code to generate the matrix in frequency domain is based on an adaptation of a code, to
// generate MDS matrices efficiently in original domain, that was developed by the Polygon Zero
// team.
const MDS_FREQ_BLOCK_ONE: [i64; 3] = [16, 8, 16];
const MDS_FREQ_BLOCK_TWO: [(i64, i64); 3] = [(-1, 2), (-1, 1), (4, 8)];
const MDS_FREQ_BLOCK_THREE: [i64; 3] = [-8, 1, 1];
// We use split 3 x 4 FFT transform in order to transform our vectors into the frequency domain.
#[inline(always)]
pub(crate) const fn mds_multiply_freq(state: [u64; 12]) -> [u64; 12] {
pub const fn mds_multiply_freq(state: [u64; 12]) -> [u64; 12] {
let [s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] = state;
let (u0, u1, u2) = fft4_real([s0, s3, s6, s9]);
@@ -156,9 +159,10 @@ const fn block3(x: [i64; 3], y: [i64; 3]) -> [i64; 3] {
#[cfg(test)]
mod tests {
use super::super::{Felt, Rpo256, MDS, ZERO};
use proptest::prelude::*;
use super::super::{apply_mds, Felt, MDS, ZERO};
const STATE_WIDTH: usize = 12;
#[inline(always)]
@@ -185,7 +189,7 @@ mod tests {
v2 = v1;
apply_mds_naive(&mut v1);
Rpo256::apply_mds(&mut v2);
apply_mds(&mut v2);
prop_assert_eq!(v1, v2);
}

src/hash/rescue/mds/mod.rs Normal file

@@ -0,0 +1,214 @@
use super::{Felt, STATE_WIDTH, ZERO};
mod freq;
pub use freq::mds_multiply_freq;
// MDS MULTIPLICATION
// ================================================================================================
#[inline(always)]
pub fn apply_mds(state: &mut [Felt; STATE_WIDTH]) {
let mut result = [ZERO; STATE_WIDTH];
// Using the linearity of the operations, we can split the state into a low||high decomposition,
// operate on each half without overflow, and then combine/reduce the result to a field element.
// The absence of overflow is guaranteed by the fact that the MDS matrix entries are small
// powers of two in the frequency domain.
let mut state_l = [0u64; STATE_WIDTH];
let mut state_h = [0u64; STATE_WIDTH];
for r in 0..STATE_WIDTH {
let s = state[r].inner();
state_h[r] = s >> 32;
state_l[r] = (s as u32) as u64;
}
let state_h = mds_multiply_freq(state_h);
let state_l = mds_multiply_freq(state_l);
for r in 0..STATE_WIDTH {
let s = state_l[r] as u128 + ((state_h[r] as u128) << 32);
let s_hi = (s >> 64) as u64;
let s_lo = s as u64;
let z = (s_hi << 32) - s_hi;
let (res, over) = s_lo.overflowing_add(z);
result[r] = Felt::from_mont(res.wrapping_add(0u32.wrapping_sub(over as u32) as u64));
}
*state = result;
}
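The final reduction in `apply_mds` folds a value below 2^96 into a single `u64` congruent to it mod p, using 2^64 ≡ 2^32 − 1 (mod p) for the Goldilocks prime p = 2^64 − 2^32 + 1. A standalone sketch of just that step (the real code additionally stays in Montgomery form via `from_mont`):

```rust
// Fold s < 2^96 into a u64 congruent to s mod the Goldilocks prime.
const P: u64 = 0xffff_ffff_0000_0001;

fn reduce(s: u128) -> u64 {
    debug_assert!(s < 1u128 << 96);
    let s_lo = s as u64;
    let s_hi = (s >> 64) as u64; // < 2^32 by the precondition
    let z = (s_hi << 32) - s_hi; // s_hi * (2^32 - 1), i.e. s_hi * 2^64 mod p
    let (res, over) = s_lo.overflowing_add(z);
    // if the addition wrapped, the lost 2^64 is re-added as 2^32 - 1
    res.wrapping_add(0u32.wrapping_sub(over as u32) as u64)
}

fn main() {
    let samples: [u128; 4] = [0, P as u128, 1 << 95, (u64::MAX as u128) + (1u128 << 64)];
    for &s in &samples {
        // the result may not be fully reduced, but it is congruent mod p
        assert_eq!(reduce(s) as u128 % P as u128, s % P as u128);
    }
    println!("ok");
}
```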
// MDS MATRIX
// ================================================================================================
/// RPO MDS matrix
pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = [
[
Felt::new(7),
Felt::new(23),
Felt::new(8),
Felt::new(26),
Felt::new(13),
Felt::new(10),
Felt::new(9),
Felt::new(7),
Felt::new(6),
Felt::new(22),
Felt::new(21),
Felt::new(8),
],
[
Felt::new(8),
Felt::new(7),
Felt::new(23),
Felt::new(8),
Felt::new(26),
Felt::new(13),
Felt::new(10),
Felt::new(9),
Felt::new(7),
Felt::new(6),
Felt::new(22),
Felt::new(21),
],
[
Felt::new(21),
Felt::new(8),
Felt::new(7),
Felt::new(23),
Felt::new(8),
Felt::new(26),
Felt::new(13),
Felt::new(10),
Felt::new(9),
Felt::new(7),
Felt::new(6),
Felt::new(22),
],
[
Felt::new(22),
Felt::new(21),
Felt::new(8),
Felt::new(7),
Felt::new(23),
Felt::new(8),
Felt::new(26),
Felt::new(13),
Felt::new(10),
Felt::new(9),
Felt::new(7),
Felt::new(6),
],
[
Felt::new(6),
Felt::new(22),
Felt::new(21),
Felt::new(8),
Felt::new(7),
Felt::new(23),
Felt::new(8),
Felt::new(26),
Felt::new(13),
Felt::new(10),
Felt::new(9),
Felt::new(7),
],
[
Felt::new(7),
Felt::new(6),
Felt::new(22),
Felt::new(21),
Felt::new(8),
Felt::new(7),
Felt::new(23),
Felt::new(8),
Felt::new(26),
Felt::new(13),
Felt::new(10),
Felt::new(9),
],
[
Felt::new(9),
Felt::new(7),
Felt::new(6),
Felt::new(22),
Felt::new(21),
Felt::new(8),
Felt::new(7),
Felt::new(23),
Felt::new(8),
Felt::new(26),
Felt::new(13),
Felt::new(10),
],
[
Felt::new(10),
Felt::new(9),
Felt::new(7),
Felt::new(6),
Felt::new(22),
Felt::new(21),
Felt::new(8),
Felt::new(7),
Felt::new(23),
Felt::new(8),
Felt::new(26),
Felt::new(13),
],
[
Felt::new(13),
Felt::new(10),
Felt::new(9),
Felt::new(7),
Felt::new(6),
Felt::new(22),
Felt::new(21),
Felt::new(8),
Felt::new(7),
Felt::new(23),
Felt::new(8),
Felt::new(26),
],
[
Felt::new(26),
Felt::new(13),
Felt::new(10),
Felt::new(9),
Felt::new(7),
Felt::new(6),
Felt::new(22),
Felt::new(21),
Felt::new(8),
Felt::new(7),
Felt::new(23),
Felt::new(8),
],
[
Felt::new(8),
Felt::new(26),
Felt::new(13),
Felt::new(10),
Felt::new(9),
Felt::new(7),
Felt::new(6),
Felt::new(22),
Felt::new(21),
Felt::new(8),
Felt::new(7),
Felt::new(23),
],
[
Felt::new(23),
Felt::new(8),
Felt::new(26),
Felt::new(13),
Felt::new(10),
Felt::new(9),
Felt::new(7),
Felt::new(6),
Felt::new(22),
Felt::new(21),
Felt::new(8),
Felt::new(7),
],
];
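The matrix above is circulant: each row is the previous row rotated right by one position, which is what makes multiplication via the frequency domain (`mds_multiply_freq`) possible. A small sketch of the indexing rule, using the raw coefficients (`MDS_ROW` and `mds_entry` are illustrative names, not part of this crate; indices are assumed to be below 12):

```rust
/// First row of the RPO MDS matrix; row i of the full matrix is this row
/// rotated right by i positions.
const MDS_ROW: [u64; 12] = [7, 23, 8, 26, 13, 10, 9, 7, 6, 22, 21, 8];

/// Entry (row, col) of the circulant MDS matrix, for row, col < 12.
fn mds_entry(row: usize, col: usize) -> u64 {
    MDS_ROW[(col + 12 - row) % 12]
}
```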

src/hash/rescue/mod.rs

@@ -0,0 +1,347 @@
use core::ops::Range;
use super::{CubeExtension, Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ZERO};
mod arch;
pub use arch::optimized::{add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox};
mod mds;
pub(crate) use mds::{apply_mds, MDS};
mod rpo;
pub use rpo::{Rpo256, RpoDigest, RpoDigestError};
mod rpx;
pub use rpx::{Rpx256, RpxDigest, RpxDigestError};
#[cfg(test)]
mod tests;
// CONSTANTS
// ================================================================================================
/// The number of rounds is set to 7. For the RPO hash function, all rounds are identical. For the
/// RPX hash function, there are 3 different types of rounds.
const NUM_ROUNDS: usize = 7;
/// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate and
/// the remaining 4 elements are reserved for capacity.
pub(crate) const STATE_WIDTH: usize = 12;
/// The rate portion of the state is located in elements 4 through 11.
const RATE_RANGE: Range<usize> = 4..12;
const RATE_WIDTH: usize = RATE_RANGE.end - RATE_RANGE.start;
const INPUT1_RANGE: Range<usize> = 4..8;
const INPUT2_RANGE: Range<usize> = 8..12;
/// The capacity portion of the state is located in elements 0, 1, 2, and 3.
const CAPACITY_RANGE: Range<usize> = 0..4;
/// The output of the hash function is a digest which consists of 4 field elements or 32 bytes.
///
/// The digest is returned from state elements 4, 5, 6, and 7 (the first four elements of the
/// rate portion).
pub(crate) const DIGEST_RANGE: Range<usize> = 4..8;
pub(crate) const DIGEST_SIZE: usize = DIGEST_RANGE.end - DIGEST_RANGE.start;
/// The number of bytes needed to encode a digest
const DIGEST_BYTES: usize = 32;
/// The number of byte chunks defining a field element when hashing a sequence of bytes
const BINARY_CHUNK_SIZE: usize = 7;
/// S-Box and Inverse S-Box powers;
///
/// The constants are defined for tests only because the exponentiations in the code are unrolled
/// for efficiency reasons.
#[cfg(test)]
const ALPHA: u64 = 7;
#[cfg(test)]
const INV_ALPHA: u64 = 10540996611094048183;
// SBOX FUNCTION
// ================================================================================================
#[inline(always)]
fn apply_sbox(state: &mut [Felt; STATE_WIDTH]) {
state[0] = state[0].exp7();
state[1] = state[1].exp7();
state[2] = state[2].exp7();
state[3] = state[3].exp7();
state[4] = state[4].exp7();
state[5] = state[5].exp7();
state[6] = state[6].exp7();
state[7] = state[7].exp7();
state[8] = state[8].exp7();
state[9] = state[9].exp7();
state[10] = state[10].exp7();
state[11] = state[11].exp7();
}
// INVERSE SBOX FUNCTION
// ================================================================================================
#[inline(always)]
fn apply_inv_sbox(state: &mut [Felt; STATE_WIDTH]) {
// compute base^10540996611094048183 using 72 multiplications per array element
// 10540996611094048183 = b1001001001001001001001001001000110110110110110110110110110110111
// compute base^10
let mut t1 = *state;
t1.iter_mut().for_each(|t| *t = t.square());
// compute base^100
let mut t2 = t1;
t2.iter_mut().for_each(|t| *t = t.square());
// compute base^100100
let t3 = exp_acc::<Felt, STATE_WIDTH, 3>(t2, t2);
// compute base^100100100100
let t4 = exp_acc::<Felt, STATE_WIDTH, 6>(t3, t3);
// compute base^100100100100100100100100
let t5 = exp_acc::<Felt, STATE_WIDTH, 12>(t4, t4);
// compute base^100100100100100100100100100100
let t6 = exp_acc::<Felt, STATE_WIDTH, 6>(t5, t3);
// compute base^1001001001001001001001001001000100100100100100100100100100100
let t7 = exp_acc::<Felt, STATE_WIDTH, 31>(t6, t6);
// compute base^1001001001001001001001001001000110110110110110110110110110110111
for (i, s) in state.iter_mut().enumerate() {
let a = (t7[i].square() * t6[i]).square().square();
let b = t1[i] * t2[i] * *s;
*s = a * b;
}
#[inline(always)]
fn exp_acc<B: StarkField, const N: usize, const M: usize>(
base: [B; N],
tail: [B; N],
) -> [B; N] {
let mut result = base;
for _ in 0..M {
result.iter_mut().for_each(|r| *r = r.square());
}
result.iter_mut().zip(tail).for_each(|(r, t)| *r *= t);
result
}
}
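The addition chain above computes x^INV_ALPHA, the inverse of the degree-7 S-box; since 7 · INV_ALPHA ≡ 1 (mod p − 1), applying both S-boxes in sequence is the identity. A sketch checking this with generic square-and-multiply exponentiation (`mul_mod` and `pow_mod` are illustrative helpers; the constants are copied from this module):

```rust
const P: u64 = 0xFFFF_FFFF_0000_0001; // 2^64 - 2^32 + 1
const ALPHA: u64 = 7;
const INV_ALPHA: u64 = 10540996611094048183;

/// Multiplication modulo p via a 128-bit intermediate.
fn mul_mod(a: u64, b: u64) -> u64 {
    ((a as u128 * b as u128) % P as u128) as u64
}

/// Square-and-multiply modular exponentiation.
fn pow_mod(mut base: u64, mut exp: u64) -> u64 {
    let mut acc = 1u64;
    while exp > 0 {
        if exp & 1 == 1 {
            acc = mul_mod(acc, base);
        }
        base = mul_mod(base, base);
        exp >>= 1;
    }
    acc
}
```

Because gcd(7, p − 1) = 1, x ↦ x^7 is a permutation of the field and x^INV_ALPHA is its inverse.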
#[inline(always)]
fn add_constants(state: &mut [Felt; STATE_WIDTH], ark: &[Felt; STATE_WIDTH]) {
state.iter_mut().zip(ark).for_each(|(s, &k)| *s += k);
}
// ROUND CONSTANTS
// ================================================================================================
/// Rescue round constants;
/// computed as in [specifications](https://github.com/ASDiscreteMathematics/rpo)
///
/// The constants are broken up into two arrays, ARK1 and ARK2; ARK1 contains the constants for the
/// first half of an RPO round, and ARK2 contains the constants for the second half.
pub(crate) const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
[
Felt::new(5789762306288267392),
Felt::new(6522564764413701783),
Felt::new(17809893479458208203),
Felt::new(107145243989736508),
Felt::new(6388978042437517382),
Felt::new(15844067734406016715),
Felt::new(9975000513555218239),
Felt::new(3344984123768313364),
Felt::new(9959189626657347191),
Felt::new(12960773468763563665),
Felt::new(9602914297752488475),
Felt::new(16657542370200465908),
],
[
Felt::new(12987190162843096997),
Felt::new(653957632802705281),
Felt::new(4441654670647621225),
Felt::new(4038207883745915761),
Felt::new(5613464648874830118),
Felt::new(13222989726778338773),
Felt::new(3037761201230264149),
Felt::new(16683759727265180203),
Felt::new(8337364536491240715),
Felt::new(3227397518293416448),
Felt::new(8110510111539674682),
Felt::new(2872078294163232137),
],
[
Felt::new(18072785500942327487),
Felt::new(6200974112677013481),
Felt::new(17682092219085884187),
Felt::new(10599526828986756440),
Felt::new(975003873302957338),
Felt::new(8264241093196931281),
Felt::new(10065763900435475170),
Felt::new(2181131744534710197),
Felt::new(6317303992309418647),
Felt::new(1401440938888741532),
Felt::new(8884468225181997494),
Felt::new(13066900325715521532),
],
[
Felt::new(5674685213610121970),
Felt::new(5759084860419474071),
Felt::new(13943282657648897737),
Felt::new(1352748651966375394),
Felt::new(17110913224029905221),
Felt::new(1003883795902368422),
Felt::new(4141870621881018291),
Felt::new(8121410972417424656),
Felt::new(14300518605864919529),
Felt::new(13712227150607670181),
Felt::new(17021852944633065291),
Felt::new(6252096473787587650),
],
[
Felt::new(4887609836208846458),
Felt::new(3027115137917284492),
Felt::new(9595098600469470675),
Felt::new(10528569829048484079),
Felt::new(7864689113198939815),
Felt::new(17533723827845969040),
Felt::new(5781638039037710951),
Felt::new(17024078752430719006),
Felt::new(109659393484013511),
Felt::new(7158933660534805869),
Felt::new(2955076958026921730),
Felt::new(7433723648458773977),
],
[
Felt::new(16308865189192447297),
Felt::new(11977192855656444890),
Felt::new(12532242556065780287),
Felt::new(14594890931430968898),
Felt::new(7291784239689209784),
Felt::new(5514718540551361949),
Felt::new(10025733853830934803),
Felt::new(7293794580341021693),
Felt::new(6728552937464861756),
Felt::new(6332385040983343262),
Felt::new(13277683694236792804),
Felt::new(2600778905124452676),
],
[
Felt::new(7123075680859040534),
Felt::new(1034205548717903090),
Felt::new(7717824418247931797),
Felt::new(3019070937878604058),
Felt::new(11403792746066867460),
Felt::new(10280580802233112374),
Felt::new(337153209462421218),
Felt::new(13333398568519923717),
Felt::new(3596153696935337464),
Felt::new(8104208463525993784),
Felt::new(14345062289456085693),
Felt::new(17036731477169661256),
],
];
pub(crate) const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
[
Felt::new(6077062762357204287),
Felt::new(15277620170502011191),
Felt::new(5358738125714196705),
Felt::new(14233283787297595718),
Felt::new(13792579614346651365),
Felt::new(11614812331536767105),
Felt::new(14871063686742261166),
Felt::new(10148237148793043499),
Felt::new(4457428952329675767),
Felt::new(15590786458219172475),
Felt::new(10063319113072092615),
Felt::new(14200078843431360086),
],
[
Felt::new(6202948458916099932),
Felt::new(17690140365333231091),
Felt::new(3595001575307484651),
Felt::new(373995945117666487),
Felt::new(1235734395091296013),
Felt::new(14172757457833931602),
Felt::new(707573103686350224),
Felt::new(15453217512188187135),
Felt::new(219777875004506018),
Felt::new(17876696346199469008),
Felt::new(17731621626449383378),
Felt::new(2897136237748376248),
],
[
Felt::new(8023374565629191455),
Felt::new(15013690343205953430),
Felt::new(4485500052507912973),
Felt::new(12489737547229155153),
Felt::new(9500452585969030576),
Felt::new(2054001340201038870),
Felt::new(12420704059284934186),
Felt::new(355990932618543755),
Felt::new(9071225051243523860),
Felt::new(12766199826003448536),
Felt::new(9045979173463556963),
Felt::new(12934431667190679898),
],
[
Felt::new(18389244934624494276),
Felt::new(16731736864863925227),
Felt::new(4440209734760478192),
Felt::new(17208448209698888938),
Felt::new(8739495587021565984),
Felt::new(17000774922218161967),
Felt::new(13533282547195532087),
Felt::new(525402848358706231),
Felt::new(16987541523062161972),
Felt::new(5466806524462797102),
Felt::new(14512769585918244983),
Felt::new(10973956031244051118),
],
[
Felt::new(6982293561042362913),
Felt::new(14065426295947720331),
Felt::new(16451845770444974180),
Felt::new(7139138592091306727),
Felt::new(9012006439959783127),
Felt::new(14619614108529063361),
Felt::new(1394813199588124371),
Felt::new(4635111139507788575),
Felt::new(16217473952264203365),
Felt::new(10782018226466330683),
Felt::new(6844229992533662050),
Felt::new(7446486531695178711),
],
[
Felt::new(3736792340494631448),
Felt::new(577852220195055341),
Felt::new(6689998335515779805),
Felt::new(13886063479078013492),
Felt::new(14358505101923202168),
Felt::new(7744142531772274164),
Felt::new(16135070735728404443),
Felt::new(12290902521256031137),
Felt::new(12059913662657709804),
Felt::new(16456018495793751911),
Felt::new(4571485474751953524),
Felt::new(17200392109565783176),
],
[
Felt::new(17130398059294018733),
Felt::new(519782857322261988),
Felt::new(9625384390925085478),
Felt::new(1664893052631119222),
Felt::new(7629576092524553570),
Felt::new(3485239601103661425),
Felt::new(9755891797164033838),
Felt::new(15218148195153269027),
Felt::new(16460604813734957368),
Felt::new(9643968136937729763),
Felt::new(3611348709641382851),
Felt::new(18256379591337759196),
],
];


@@ -0,0 +1,650 @@
use alloc::string::String;
use core::{cmp::Ordering, fmt::Display, ops::Deref, slice};
use rand::{
distributions::{Standard, Uniform},
prelude::Distribution,
};
use thiserror::Error;
use super::{Digest, Felt, StarkField, DIGEST_BYTES, DIGEST_SIZE, ZERO};
use crate::{
rand::Randomizable,
utils::{
bytes_to_hex_string, hex_to_bytes, ByteReader, ByteWriter, Deserializable,
DeserializationError, HexParseError, Serializable,
},
};
// DIGEST TRAIT IMPLEMENTATIONS
// ================================================================================================
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))]
pub struct RpoDigest([Felt; DIGEST_SIZE]);
impl RpoDigest {
/// The serialized size of the digest in bytes.
pub const SERIALIZED_SIZE: usize = DIGEST_BYTES;
pub const fn new(value: [Felt; DIGEST_SIZE]) -> Self {
Self(value)
}
pub fn as_elements(&self) -> &[Felt] {
self.as_ref()
}
pub fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
<Self as Digest>::as_bytes(self)
}
pub fn digests_as_elements_iter<'a, I>(digests: I) -> impl Iterator<Item = &'a Felt>
where
I: Iterator<Item = &'a Self>,
{
digests.flat_map(|d| d.0.iter())
}
pub fn digests_as_elements(digests: &[Self]) -> &[Felt] {
let p = digests.as_ptr();
let len = digests.len() * DIGEST_SIZE;
// SAFETY: `RpoDigest` is a newtype over `[Felt; DIGEST_SIZE]`, so a slice of digests
// has the same layout as a flat slice of the underlying field elements.
unsafe { slice::from_raw_parts(p as *const Felt, len) }
}
/// Returns hexadecimal representation of this digest prefixed with `0x`.
pub fn to_hex(&self) -> String {
bytes_to_hex_string(self.as_bytes())
}
}
impl Digest for RpoDigest {
fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
let mut result = [0; DIGEST_BYTES];
result[..8].copy_from_slice(&self.0[0].as_int().to_le_bytes());
result[8..16].copy_from_slice(&self.0[1].as_int().to_le_bytes());
result[16..24].copy_from_slice(&self.0[2].as_int().to_le_bytes());
result[24..].copy_from_slice(&self.0[3].as_int().to_le_bytes());
result
}
}
impl Deref for RpoDigest {
type Target = [Felt; DIGEST_SIZE];
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl Ord for RpoDigest {
fn cmp(&self, other: &Self) -> Ordering {
// compare the inner u64 representations of both digests.
//
// this iterates over the elements and returns the first comparison that is not
// `Equal`; otherwise, the ordering is `Equal`.
//
// the endianness is irrelevant here because, this being a cryptographically secure
// hash, the digest shouldn't preserve any ordering property of its input.
//
// finally, we use `Felt::inner` instead of `Felt::as_int` to avoid performing a
// Montgomery reduction for every limb. this is safe because every inner element of the
// digest is guaranteed to be in canonical form (that is, `x in [0, p)`).
self.0.iter().map(Felt::inner).zip(other.0.iter().map(Felt::inner)).fold(
Ordering::Equal,
|ord, (a, b)| match ord {
Ordering::Equal => a.cmp(&b),
_ => ord,
},
)
}
}
impl PartialOrd for RpoDigest {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
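The fold in the `Ord` impl above is a lexicographic comparison: it keeps the first non-`Equal` result and ignores all later ones. The same pattern on plain `u64` limbs, as an illustration (`cmp_limbs` is a hypothetical name, not part of this crate):

```rust
use core::cmp::Ordering;

/// Lexicographic comparison of fixed-size limb arrays, mirroring the fold used
/// by the digest's `Ord` impl.
fn cmp_limbs(a: &[u64; 4], b: &[u64; 4]) -> Ordering {
    a.iter().zip(b.iter()).fold(Ordering::Equal, |ord, (x, y)| match ord {
        // only the first differing limb decides the ordering
        Ordering::Equal => x.cmp(y),
        _ => ord,
    })
}
```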
impl Display for RpoDigest {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
let encoded: String = self.into();
write!(f, "{}", encoded)?;
Ok(())
}
}
impl Randomizable for RpoDigest {
const VALUE_SIZE: usize = DIGEST_BYTES;
fn from_random_bytes(bytes: &[u8]) -> Option<Self> {
let bytes_array: Option<[u8; 32]> = bytes.try_into().ok();
if let Some(bytes_array) = bytes_array {
Self::try_from(bytes_array).ok()
} else {
None
}
}
}
impl Distribution<RpoDigest> for Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> RpoDigest {
let mut res = [ZERO; DIGEST_SIZE];
let uni_dist = Uniform::from(0..Felt::MODULUS);
for r in res.iter_mut() {
let sampled_integer = uni_dist.sample(rng);
*r = Felt::new(sampled_integer);
}
RpoDigest::new(res)
}
}
// CONVERSIONS: FROM RPO DIGEST
// ================================================================================================
#[derive(Debug, Error)]
pub enum RpoDigestError {
#[error("failed to convert digest field element to {0}")]
TypeConversion(&'static str),
#[error("failed to convert to field element: {0}")]
InvalidFieldElement(String),
}
impl TryFrom<&RpoDigest> for [bool; DIGEST_SIZE] {
type Error = RpoDigestError;
fn try_from(value: &RpoDigest) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<RpoDigest> for [bool; DIGEST_SIZE] {
type Error = RpoDigestError;
fn try_from(value: RpoDigest) -> Result<Self, Self::Error> {
fn to_bool(v: u64) -> Option<bool> {
if v <= 1 {
Some(v == 1)
} else {
None
}
}
Ok([
to_bool(value.0[0].as_int()).ok_or(RpoDigestError::TypeConversion("bool"))?,
to_bool(value.0[1].as_int()).ok_or(RpoDigestError::TypeConversion("bool"))?,
to_bool(value.0[2].as_int()).ok_or(RpoDigestError::TypeConversion("bool"))?,
to_bool(value.0[3].as_int()).ok_or(RpoDigestError::TypeConversion("bool"))?,
])
}
}
impl TryFrom<&RpoDigest> for [u8; DIGEST_SIZE] {
type Error = RpoDigestError;
fn try_from(value: &RpoDigest) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<RpoDigest> for [u8; DIGEST_SIZE] {
type Error = RpoDigestError;
fn try_from(value: RpoDigest) -> Result<Self, Self::Error> {
Ok([
value.0[0]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u8"))?,
value.0[1]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u8"))?,
value.0[2]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u8"))?,
value.0[3]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u8"))?,
])
}
}
impl TryFrom<&RpoDigest> for [u16; DIGEST_SIZE] {
type Error = RpoDigestError;
fn try_from(value: &RpoDigest) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<RpoDigest> for [u16; DIGEST_SIZE] {
type Error = RpoDigestError;
fn try_from(value: RpoDigest) -> Result<Self, Self::Error> {
Ok([
value.0[0]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u16"))?,
value.0[1]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u16"))?,
value.0[2]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u16"))?,
value.0[3]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u16"))?,
])
}
}
impl TryFrom<&RpoDigest> for [u32; DIGEST_SIZE] {
type Error = RpoDigestError;
fn try_from(value: &RpoDigest) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<RpoDigest> for [u32; DIGEST_SIZE] {
type Error = RpoDigestError;
fn try_from(value: RpoDigest) -> Result<Self, Self::Error> {
Ok([
value.0[0]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u32"))?,
value.0[1]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u32"))?,
value.0[2]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u32"))?,
value.0[3]
.as_int()
.try_into()
.map_err(|_| RpoDigestError::TypeConversion("u32"))?,
])
}
}
impl From<&RpoDigest> for [u64; DIGEST_SIZE] {
fn from(value: &RpoDigest) -> Self {
(*value).into()
}
}
impl From<RpoDigest> for [u64; DIGEST_SIZE] {
fn from(value: RpoDigest) -> Self {
[
value.0[0].as_int(),
value.0[1].as_int(),
value.0[2].as_int(),
value.0[3].as_int(),
]
}
}
impl From<&RpoDigest> for [Felt; DIGEST_SIZE] {
fn from(value: &RpoDigest) -> Self {
(*value).into()
}
}
impl From<RpoDigest> for [Felt; DIGEST_SIZE] {
fn from(value: RpoDigest) -> Self {
value.0
}
}
impl From<&RpoDigest> for [u8; DIGEST_BYTES] {
fn from(value: &RpoDigest) -> Self {
(*value).into()
}
}
impl From<RpoDigest> for [u8; DIGEST_BYTES] {
fn from(value: RpoDigest) -> Self {
value.as_bytes()
}
}
impl From<&RpoDigest> for String {
/// The returned string starts with `0x`.
fn from(value: &RpoDigest) -> Self {
(*value).into()
}
}
impl From<RpoDigest> for String {
/// The returned string starts with `0x`.
fn from(value: RpoDigest) -> Self {
value.to_hex()
}
}
// CONVERSIONS: TO RPO DIGEST
// ================================================================================================
impl From<&[bool; DIGEST_SIZE]> for RpoDigest {
fn from(value: &[bool; DIGEST_SIZE]) -> Self {
(*value).into()
}
}
impl From<[bool; DIGEST_SIZE]> for RpoDigest {
fn from(value: [bool; DIGEST_SIZE]) -> Self {
[value[0] as u32, value[1] as u32, value[2] as u32, value[3] as u32].into()
}
}
impl From<&[u8; DIGEST_SIZE]> for RpoDigest {
fn from(value: &[u8; DIGEST_SIZE]) -> Self {
(*value).into()
}
}
impl From<[u8; DIGEST_SIZE]> for RpoDigest {
fn from(value: [u8; DIGEST_SIZE]) -> Self {
Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
}
}
impl From<&[u16; DIGEST_SIZE]> for RpoDigest {
fn from(value: &[u16; DIGEST_SIZE]) -> Self {
(*value).into()
}
}
impl From<[u16; DIGEST_SIZE]> for RpoDigest {
fn from(value: [u16; DIGEST_SIZE]) -> Self {
Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
}
}
impl From<&[u32; DIGEST_SIZE]> for RpoDigest {
fn from(value: &[u32; DIGEST_SIZE]) -> Self {
(*value).into()
}
}
impl From<[u32; DIGEST_SIZE]> for RpoDigest {
fn from(value: [u32; DIGEST_SIZE]) -> Self {
Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
}
}
impl TryFrom<&[u64; DIGEST_SIZE]> for RpoDigest {
type Error = RpoDigestError;
fn try_from(value: &[u64; DIGEST_SIZE]) -> Result<Self, RpoDigestError> {
(*value).try_into()
}
}
impl TryFrom<[u64; DIGEST_SIZE]> for RpoDigest {
type Error = RpoDigestError;
fn try_from(value: [u64; DIGEST_SIZE]) -> Result<Self, RpoDigestError> {
Ok(Self([
value[0].try_into().map_err(RpoDigestError::InvalidFieldElement)?,
value[1].try_into().map_err(RpoDigestError::InvalidFieldElement)?,
value[2].try_into().map_err(RpoDigestError::InvalidFieldElement)?,
value[3].try_into().map_err(RpoDigestError::InvalidFieldElement)?,
]))
}
}
impl From<&[Felt; DIGEST_SIZE]> for RpoDigest {
fn from(value: &[Felt; DIGEST_SIZE]) -> Self {
Self(*value)
}
}
impl From<[Felt; DIGEST_SIZE]> for RpoDigest {
fn from(value: [Felt; DIGEST_SIZE]) -> Self {
Self(value)
}
}
impl TryFrom<&[u8; DIGEST_BYTES]> for RpoDigest {
type Error = HexParseError;
fn try_from(value: &[u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<[u8; DIGEST_BYTES]> for RpoDigest {
type Error = HexParseError;
fn try_from(value: [u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
// Note: since the input length is known, the conversion from slice to array must succeed,
// so the `unwrap`s below are safe
let a = u64::from_le_bytes(value[0..8].try_into().unwrap());
let b = u64::from_le_bytes(value[8..16].try_into().unwrap());
let c = u64::from_le_bytes(value[16..24].try_into().unwrap());
let d = u64::from_le_bytes(value[24..32].try_into().unwrap());
if [a, b, c, d].iter().any(|v| *v >= Felt::MODULUS) {
return Err(HexParseError::OutOfRange);
}
Ok(RpoDigest([Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)]))
}
}
impl TryFrom<&[u8]> for RpoDigest {
type Error = HexParseError;
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
// convert the slice into a fixed-size array first; calling `(*value).try_into()`
// here would resolve back to this same impl and recurse infinitely
let bytes: [u8; DIGEST_BYTES] =
value.try_into().map_err(|_| HexParseError::InvalidLength {
expected: DIGEST_BYTES,
actual: value.len(),
})?;
bytes.try_into()
}
}
impl TryFrom<&str> for RpoDigest {
type Error = HexParseError;
/// Expects the string to start with `0x`.
fn try_from(value: &str) -> Result<Self, Self::Error> {
hex_to_bytes::<DIGEST_BYTES>(value).and_then(RpoDigest::try_from)
}
}
impl TryFrom<String> for RpoDigest {
type Error = HexParseError;
/// Expects the string to start with `0x`.
fn try_from(value: String) -> Result<Self, Self::Error> {
value.as_str().try_into()
}
}
impl TryFrom<&String> for RpoDigest {
type Error = HexParseError;
/// Expects the string to start with `0x`.
fn try_from(value: &String) -> Result<Self, Self::Error> {
value.as_str().try_into()
}
}
// SERIALIZATION / DESERIALIZATION
// ================================================================================================
impl Serializable for RpoDigest {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write_bytes(&self.as_bytes());
}
fn get_size_hint(&self) -> usize {
Self::SERIALIZED_SIZE
}
}
impl Deserializable for RpoDigest {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let mut inner: [Felt; DIGEST_SIZE] = [ZERO; DIGEST_SIZE];
for inner in inner.iter_mut() {
let e = source.read_u64()?;
if e >= Felt::MODULUS {
return Err(DeserializationError::InvalidValue(String::from(
"Value not in the appropriate range",
)));
}
*inner = Felt::new(e);
}
Ok(Self(inner))
}
}
// ITERATORS
// ================================================================================================
impl IntoIterator for RpoDigest {
type Item = Felt;
type IntoIter = <[Felt; 4] as IntoIterator>::IntoIter;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
use alloc::string::String;
use rand_utils::rand_value;
use super::{Deserializable, Felt, RpoDigest, Serializable, DIGEST_BYTES, DIGEST_SIZE};
use crate::utils::SliceReader;
#[test]
fn digest_serialization() {
let e1 = Felt::new(rand_value());
let e2 = Felt::new(rand_value());
let e3 = Felt::new(rand_value());
let e4 = Felt::new(rand_value());
let d1 = RpoDigest([e1, e2, e3, e4]);
let mut bytes = vec![];
d1.write_into(&mut bytes);
assert_eq!(DIGEST_BYTES, bytes.len());
assert_eq!(bytes.len(), d1.get_size_hint());
let mut reader = SliceReader::new(&bytes);
let d2 = RpoDigest::read_from(&mut reader).unwrap();
assert_eq!(d1, d2);
}
#[test]
fn digest_encoding() {
let digest = RpoDigest([
Felt::new(rand_value()),
Felt::new(rand_value()),
Felt::new(rand_value()),
Felt::new(rand_value()),
]);
let string: String = digest.into();
let round_trip: RpoDigest = string.try_into().expect("decoding failed");
assert_eq!(digest, round_trip);
}
#[test]
fn test_conversions() {
let digest = RpoDigest([
Felt::new(rand_value()),
Felt::new(rand_value()),
Felt::new(rand_value()),
Felt::new(rand_value()),
]);
// BY VALUE
// ----------------------------------------------------------------------------------------
let v: [bool; DIGEST_SIZE] = [true, false, true, true];
let v2: RpoDigest = v.into();
assert_eq!(v, <[bool; DIGEST_SIZE]>::try_from(v2).unwrap());
let v: [u8; DIGEST_SIZE] = [0_u8, 1_u8, 2_u8, 3_u8];
let v2: RpoDigest = v.into();
assert_eq!(v, <[u8; DIGEST_SIZE]>::try_from(v2).unwrap());
let v: [u16; DIGEST_SIZE] = [0_u16, 1_u16, 2_u16, 3_u16];
let v2: RpoDigest = v.into();
assert_eq!(v, <[u16; DIGEST_SIZE]>::try_from(v2).unwrap());
let v: [u32; DIGEST_SIZE] = [0_u32, 1_u32, 2_u32, 3_u32];
let v2: RpoDigest = v.into();
assert_eq!(v, <[u32; DIGEST_SIZE]>::try_from(v2).unwrap());
let v: [u64; DIGEST_SIZE] = digest.into();
let v2: RpoDigest = v.try_into().unwrap();
assert_eq!(digest, v2);
let v: [Felt; DIGEST_SIZE] = digest.into();
let v2: RpoDigest = v.into();
assert_eq!(digest, v2);
let v: [u8; DIGEST_BYTES] = digest.into();
let v2: RpoDigest = v.try_into().unwrap();
assert_eq!(digest, v2);
let v: String = digest.into();
let v2: RpoDigest = v.try_into().unwrap();
assert_eq!(digest, v2);
// BY REF
// ----------------------------------------------------------------------------------------
let v: [bool; DIGEST_SIZE] = [true, false, true, true];
let v2: RpoDigest = (&v).into();
assert_eq!(v, <[bool; DIGEST_SIZE]>::try_from(&v2).unwrap());
let v: [u8; DIGEST_SIZE] = [0_u8, 1_u8, 2_u8, 3_u8];
let v2: RpoDigest = (&v).into();
assert_eq!(v, <[u8; DIGEST_SIZE]>::try_from(&v2).unwrap());
let v: [u16; DIGEST_SIZE] = [0_u16, 1_u16, 2_u16, 3_u16];
let v2: RpoDigest = (&v).into();
assert_eq!(v, <[u16; DIGEST_SIZE]>::try_from(&v2).unwrap());
let v: [u32; DIGEST_SIZE] = [0_u32, 1_u32, 2_u32, 3_u32];
let v2: RpoDigest = (&v).into();
assert_eq!(v, <[u32; DIGEST_SIZE]>::try_from(&v2).unwrap());
let v: [u64; DIGEST_SIZE] = (&digest).into();
let v2: RpoDigest = (&v).try_into().unwrap();
assert_eq!(digest, v2);
let v: [Felt; DIGEST_SIZE] = (&digest).into();
let v2: RpoDigest = (&v).into();
assert_eq!(digest, v2);
let v: [u8; DIGEST_BYTES] = (&digest).into();
let v2: RpoDigest = (&v).try_into().unwrap();
assert_eq!(digest, v2);
let v: String = (&digest).into();
let v2: RpoDigest = (&v).try_into().unwrap();
assert_eq!(digest, v2);
}
}

src/hash/rescue/rpo/mod.rs

@@ -0,0 +1,339 @@
use core::ops::Range;
use super::{
add_constants, add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox, apply_inv_sbox,
apply_mds, apply_sbox, Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ARK1,
ARK2, BINARY_CHUNK_SIZE, CAPACITY_RANGE, DIGEST_BYTES, DIGEST_RANGE, DIGEST_SIZE, INPUT1_RANGE,
INPUT2_RANGE, MDS, NUM_ROUNDS, RATE_RANGE, RATE_WIDTH, STATE_WIDTH, ZERO,
};
mod digest;
pub use digest::{RpoDigest, RpoDigestError};
#[cfg(test)]
mod tests;
// HASHER IMPLEMENTATION
// ================================================================================================
/// Implementation of the Rescue Prime Optimized hash function with 256-bit output.
///
/// The hash function is implemented according to the Rescue Prime Optimized
/// [specifications](https://eprint.iacr.org/2022/1577) while the padding rule follows the one
/// described [here](https://eprint.iacr.org/2023/1045).
///
/// The parameters used to instantiate the function are:
/// * Field: 64-bit prime field with modulus p = 2^64 - 2^32 + 1.
/// * State width: 12 field elements.
/// * Rate size: r = 8 field elements.
/// * Capacity size: c = 4 field elements.
/// * Number of rounds: 7.
/// * S-Box degree: 7.
///
/// The above parameters target a 128-bit security level. The digest consists of four field elements
/// and it can be serialized into 32 bytes (256 bits).
///
/// ## Hash output consistency
/// Functions [hash_elements()](Rpo256::hash_elements), [merge()](Rpo256::merge), and
/// [merge_with_int()](Rpo256::merge_with_int) are internally consistent. That is, computing
/// a hash for the same set of elements using these functions will always produce the same
/// result. For example, merging two digests using [merge()](Rpo256::merge) will produce the
/// same result as hashing 8 elements which make up these digests using
/// [hash_elements()](Rpo256::hash_elements) function.
///
/// However, the [hash()](Rpo256::hash) function is not consistent with the functions mentioned above.
/// For example, if we take two field elements, serialize them to bytes and hash them using
/// [hash()](Rpo256::hash), the result will differ from the result obtained by hashing these
/// elements directly using [hash_elements()](Rpo256::hash_elements) function. The reason for
/// this difference is that [hash()](Rpo256::hash) function needs to be able to handle
/// arbitrary binary strings, which may or may not encode valid field elements - and thus,
/// deserialization procedure used by this function is different from the procedure used to
/// deserialize valid field elements.
///
/// Thus, if the underlying data consists of valid field elements, it might make more sense
/// to deserialize them into field elements and then hash them using
/// [hash_elements()](Rpo256::hash_elements) function rather than hashing the serialized bytes
/// using [hash()](Rpo256::hash) function.
///
/// ## Domain separation
/// [merge_in_domain()](Rpo256::merge_in_domain) hashes two digests into one digest with some domain
/// identifier and the current implementation sets the second capacity element to the value of
/// this domain identifier. Using a similar argument to the one formulated for domain separation of
/// the RPX hash function in Appendix C of its [specification](https://eprint.iacr.org/2023/1045),
/// one sees that doing so degrades only pre-image resistance, from its initial bound of c.log_2(p),
/// by as much as the log_2 of the size of the domain identifier space. Since pre-image resistance
/// becomes the bottleneck for the security bound of the sponge in overwrite-mode only when it is
/// lower than 2^128, we see that the target 128-bit security level is maintained as long as
/// the size of the domain identifier space, including for padding, is less than 2^128.
///
/// ## Hashing of empty input
/// The current implementation hashes empty input to the zero digest [0, 0, 0, 0]. This has
/// the benefit of requiring no calls to the RPO permutation when hashing empty input.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Rpo256();
impl Hasher for Rpo256 {
/// Rpo256 collision resistance is 128 bits.
const COLLISION_RESISTANCE: u32 = 128;
type Digest = RpoDigest;
fn hash(bytes: &[u8]) -> Self::Digest {
// initialize the state with zeroes
let mut state = [ZERO; STATE_WIDTH];
// determine the number of field elements needed to encode `bytes` when each field element
// represents at most 7 bytes.
let num_field_elem = bytes.len().div_ceil(BINARY_CHUNK_SIZE);
// set the first capacity element to `RATE_WIDTH + (num_field_elem % RATE_WIDTH)`. We do this
// to achieve:
// 1. Domain separation between hashing of `[u8]` and hashing of `[Felt]`.
// 2. Avoidance of collisions at the `[Felt]` representation of the encoded bytes.
state[CAPACITY_RANGE.start] =
Felt::from((RATE_WIDTH + (num_field_elem % RATE_WIDTH)) as u8);
// initialize a buffer to receive the little-endian elements.
let mut buf = [0_u8; 8];
// iterate over the chunks of bytes, creating a field element from each chunk and copying it
// into the state.
//
// every time the rate range is filled, a permutation is performed. if the final value of
// `rate_pos` is not zero, then the chunks were not enough to fill the rate range, and an
// additional permutation must be performed.
let mut current_chunk_idx = 0_usize;
// handle the case of an empty `bytes`
let last_chunk_idx = if num_field_elem == 0 {
current_chunk_idx
} else {
num_field_elem - 1
};
let rate_pos = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |rate_pos, chunk| {
// copy the chunk into the buffer
if current_chunk_idx != last_chunk_idx {
buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk);
} else {
// on the last iteration, we pad `buf` with a 1 followed by as many 0's as are
// needed to fill it
buf.fill(0);
buf[..chunk.len()].copy_from_slice(chunk);
buf[chunk.len()] = 1;
}
current_chunk_idx += 1;
// set the current rate element to the input. since we take at most 7 bytes, we are
// guaranteed that the input data will fit into a single field element.
state[RATE_RANGE.start + rate_pos] = Felt::new(u64::from_le_bytes(buf));
// proceed filling the range. if it's full, then we apply a permutation and reset the
// counter to the beginning of the range.
if rate_pos == RATE_WIDTH - 1 {
Self::apply_permutation(&mut state);
0
} else {
rate_pos + 1
}
});
// if we absorbed some elements but didn't apply a permutation to them (would happen when
// the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation. we
// don't need to apply any extra padding because the first capacity element contains a
// flag indicating the number of field elements constituting the last block when the latter
// is not divisible by `RATE_WIDTH`.
if rate_pos != 0 {
state[RATE_RANGE.start + rate_pos..RATE_RANGE.end].fill(ZERO);
Self::apply_permutation(&mut state);
}
// return the first 4 elements of the rate as hash result.
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}
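The byte-encoding arithmetic in `hash()` above can be sketched in isolation. This is a minimal, self-contained illustration, assuming the same constants this module uses (`BINARY_CHUNK_SIZE = 7`, `RATE_WIDTH = 8`); it is not the crate's API, just the chunk-count and capacity-flag computation.

```rust
// constants mirrored from this module (assumption: 7-byte chunks, rate of 8)
const BINARY_CHUNK_SIZE: usize = 7;
const RATE_WIDTH: usize = 8;

/// Number of field elements needed to encode `len` bytes at 7 bytes apiece.
fn num_field_elements(len: usize) -> usize {
    len.div_ceil(BINARY_CHUNK_SIZE)
}

/// First capacity element used to domain-separate byte hashing.
fn capacity_flag(len: usize) -> usize {
    RATE_WIDTH + (num_field_elements(len) % RATE_WIDTH)
}

fn main() {
    // 100 bytes -> ceil(100 / 7) = 15 elements; flag = 8 + (15 % 8) = 15
    assert_eq!(num_field_elements(100), 15);
    assert_eq!(capacity_flag(100), 15);
    // an exact multiple of the rate: 56 bytes -> 8 elements; flag = 8 + 0 = 8
    assert_eq!(capacity_flag(56), 8);
}
```

Note that the flag is always in `[RATE_WIDTH, 2 * RATE_WIDTH)`, which keeps byte hashing domain-separated from element hashing, where the flag is `elements.len() % RATE_WIDTH`.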
fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
// initialize the state by copying the digest elements into the rate portion of the state
// (8 total elements), and set the capacity elements to 0.
let mut state = [ZERO; STATE_WIDTH];
let it = Self::Digest::digests_as_elements_iter(values.iter());
for (i, v) in it.enumerate() {
state[RATE_RANGE.start + i] = *v;
}
// apply the RPO permutation and return the first four elements of the state
Self::apply_permutation(&mut state);
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}
fn merge_many(values: &[Self::Digest]) -> Self::Digest {
Self::hash_elements(Self::Digest::digests_as_elements(values))
}
fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
// initialize the state as follows:
// - seed is copied into the first 4 elements of the rate portion of the state.
// - if the value fits into a single field element, copy it into the fifth rate element and
// set the first capacity element to 5.
// - if the value doesn't fit into a single field element, split it into two field elements,
// copy them into rate elements 5 and 6 and set the first capacity element to 6.
let mut state = [ZERO; STATE_WIDTH];
state[INPUT1_RANGE].copy_from_slice(seed.as_elements());
state[INPUT2_RANGE.start] = Felt::new(value);
if value < Felt::MODULUS {
state[CAPACITY_RANGE.start] = Felt::from(5_u8);
} else {
state[INPUT2_RANGE.start + 1] = Felt::new(value / Felt::MODULUS);
state[CAPACITY_RANGE.start] = Felt::from(6_u8);
}
// apply the RPO permutation and return the first four elements of the rate
Self::apply_permutation(&mut state);
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}
}
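The limb-splitting rule in `merge_with_int()` above can be checked with plain integers. A minimal sketch, assuming the Goldilocks prime used by `Felt` (2^64 - 2^32 + 1); in the real code the low limb is additionally reduced modulo the prime by `Felt::new`.

```rust
// assumption: the field modulus used by `Felt` (Goldilocks prime)
const MODULUS: u64 = 0xFFFF_FFFF_0000_0001; // 2^64 - 2^32 + 1

/// Returns (low limb, high limb, capacity flag) for a u64 value.
fn split_value(value: u64) -> (u64, u64, u8) {
    if value < MODULUS {
        (value, 0, 5) // fits in one field element; 5 rate elements are occupied
    } else {
        (value, value / MODULUS, 6) // split across two elements; 6 occupied
    }
}

fn main() {
    assert_eq!(split_value(42), (42, 0, 5));
    // u64::MAX exceeds the modulus, so the high limb is 1 and the flag is 6
    assert_eq!(split_value(u64::MAX), (u64::MAX, 1, 6));
}
```

For any `u64` input the high limb is at most 1, since the modulus is larger than 2^63.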
impl ElementHasher for Rpo256 {
type BaseField = Felt;
fn hash_elements<E: FieldElement<BaseField = Self::BaseField>>(elements: &[E]) -> Self::Digest {
// convert the elements into a list of base field elements
let elements = E::slice_as_base_elements(elements);
// initialize state to all zeros, except for the first element of the capacity part, which
// is set to `elements.len() % RATE_WIDTH`.
let mut state = [ZERO; STATE_WIDTH];
state[CAPACITY_RANGE.start] = Self::BaseField::from((elements.len() % RATE_WIDTH) as u8);
// absorb elements into the state one by one until the rate portion of the state is filled
// up; then apply the Rescue permutation and start absorbing again; repeat until all
// elements have been absorbed
let mut i = 0;
for &element in elements.iter() {
state[RATE_RANGE.start + i] = element;
i += 1;
if i % RATE_WIDTH == 0 {
Self::apply_permutation(&mut state);
i = 0;
}
}
// if we absorbed some elements but didn't apply a permutation to them (would happen when
// the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation after
// padding by as many 0 as necessary to make the input length a multiple of the RATE_WIDTH.
if i > 0 {
while i != RATE_WIDTH {
state[RATE_RANGE.start + i] = ZERO;
i += 1;
}
Self::apply_permutation(&mut state);
}
// return the first 4 elements of the state as hash result
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}
}
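The absorption loop in `hash_elements()` above applies one permutation per full rate block, plus one more for a trailing partial block. A small sketch of that call count, assuming `RATE_WIDTH = 8` as in this module:

```rust
const RATE_WIDTH: usize = 8;

/// Number of RPO permutation calls `hash_elements` performs for `n` elements:
/// one per full rate block, plus one if a partial block remains.
fn permutation_calls(n: usize) -> usize {
    n / RATE_WIDTH + usize::from(n % RATE_WIDTH != 0)
}

fn main() {
    assert_eq!(permutation_calls(0), 0); // empty input: zero digest, no permutation
    assert_eq!(permutation_calls(8), 1); // exactly one rate block
    assert_eq!(permutation_calls(9), 2); // one full block plus a padded partial block
}
```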
// HASH FUNCTION IMPLEMENTATION
// ================================================================================================
impl Rpo256 {
// CONSTANTS
// --------------------------------------------------------------------------------------------
/// The number of rounds is set to 7 to target 128-bit security level.
pub const NUM_ROUNDS: usize = NUM_ROUNDS;
/// Sponge state is set to 12 field elements or 768 bits; 8 elements are reserved for the rate
/// and the remaining 4 elements are reserved for the capacity.
pub const STATE_WIDTH: usize = STATE_WIDTH;
/// The rate portion of the state is located in elements 4 through 11 (inclusive).
pub const RATE_RANGE: Range<usize> = RATE_RANGE;
/// The capacity portion of the state is located in elements 0, 1, 2, and 3.
pub const CAPACITY_RANGE: Range<usize> = CAPACITY_RANGE;
/// The output of the hash function can be read from state elements 4, 5, 6, and 7.
pub const DIGEST_RANGE: Range<usize> = DIGEST_RANGE;
/// MDS matrix used for computing the linear layer in an RPO round.
pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = MDS;
/// Round constants added to the hasher state in the first half of the RPO round.
pub const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK1;
/// Round constants added to the hasher state in the second half of the RPO round.
pub const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK2;
// TRAIT PASS-THROUGH FUNCTIONS
// --------------------------------------------------------------------------------------------
/// Returns a hash of the provided sequence of bytes.
#[inline(always)]
pub fn hash(bytes: &[u8]) -> RpoDigest {
<Self as Hasher>::hash(bytes)
}
/// Returns a hash of two digests. This method is intended for use in construction of
/// Merkle trees and verification of Merkle paths.
#[inline(always)]
pub fn merge(values: &[RpoDigest; 2]) -> RpoDigest {
<Self as Hasher>::merge(values)
}
/// Returns a hash of the provided field elements.
#[inline(always)]
pub fn hash_elements<E: FieldElement<BaseField = Felt>>(elements: &[E]) -> RpoDigest {
<Self as ElementHasher>::hash_elements(elements)
}
// DOMAIN IDENTIFIER
// --------------------------------------------------------------------------------------------
/// Returns a hash of two digests and a domain identifier.
pub fn merge_in_domain(values: &[RpoDigest; 2], domain: Felt) -> RpoDigest {
// initialize the state by copying the digest elements into the rate portion of the state
// (8 total elements), and set the capacity elements to 0.
let mut state = [ZERO; STATE_WIDTH];
let it = RpoDigest::digests_as_elements_iter(values.iter());
for (i, v) in it.enumerate() {
state[RATE_RANGE.start + i] = *v;
}
// set the second capacity element to the domain value. The first capacity element is used
// for padding purposes.
state[CAPACITY_RANGE.start + 1] = domain;
// apply the RPO permutation and return the first four elements of the state
Self::apply_permutation(&mut state);
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}
// RESCUE PERMUTATION
// --------------------------------------------------------------------------------------------
/// Applies RPO permutation to the provided state.
#[inline(always)]
pub fn apply_permutation(state: &mut [Felt; STATE_WIDTH]) {
for i in 0..NUM_ROUNDS {
Self::apply_round(state, i);
}
}
/// RPO round function.
#[inline(always)]
pub fn apply_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
// apply first half of RPO round
apply_mds(state);
if !add_constants_and_apply_sbox(state, &ARK1[round]) {
add_constants(state, &ARK1[round]);
apply_sbox(state);
}
// apply second half of RPO round
apply_mds(state);
if !add_constants_and_apply_inv_sbox(state, &ARK2[round]) {
add_constants(state, &ARK2[round]);
apply_inv_sbox(state);
}
}
}


@@ -1,21 +1,16 @@
use super::{
Felt, FieldElement, Hasher, Rpo256, RpoDigest, StarkField, ALPHA, INV_ALPHA, ONE, STATE_WIDTH,
ZERO,
};
use crate::{
utils::collections::{BTreeSet, Vec},
Word,
};
use core::convert::TryInto;
use alloc::{collections::BTreeSet, vec::Vec};
use proptest::prelude::*;
use rand_utils::rand_value;
#[test]
fn test_alphas() {
let e: Felt = Felt::new(rand_value());
let e_exp = e.exp(ALPHA);
assert_eq!(e, e_exp.exp(INV_ALPHA));
}
use super::{
super::{apply_inv_sbox, apply_sbox, ALPHA, INV_ALPHA},
Felt, FieldElement, Hasher, Rpo256, RpoDigest, StarkField, STATE_WIDTH, ZERO,
};
use crate::{
hash::rescue::{BINARY_CHUNK_SIZE, CAPACITY_RANGE, RATE_WIDTH},
Word, ONE,
};
#[test]
fn test_sbox() {
@@ -25,7 +20,7 @@ fn test_sbox() {
expected.iter_mut().for_each(|v| *v = v.exp(ALPHA));
let mut actual = state;
Rpo256::apply_sbox(&mut actual);
apply_sbox(&mut actual);
assert_eq!(expected, actual);
}
@@ -38,7 +33,7 @@ fn test_inv_sbox() {
expected.iter_mut().for_each(|v| *v = v.exp(INV_ALPHA));
let mut actual = state;
Rpo256::apply_inv_sbox(&mut actual);
apply_inv_sbox(&mut actual);
assert_eq!(expected, actual);
}
@@ -67,7 +62,7 @@ fn merge_vs_merge_in_domain() {
];
let merge_result = Rpo256::merge(&digests);
// ------------- merge with domain = 0 ----------------------------------------------------------
// ------------- merge with domain = 0 -------------
// set domain to ZERO. This should not change the result.
let domain = ZERO;
@@ -75,7 +70,7 @@ fn merge_vs_merge_in_domain() {
let merge_in_domain_result = Rpo256::merge_in_domain(&digests, domain);
assert_eq!(merge_result, merge_in_domain_result);
// ------------- merge with domain = 1 ----------------------------------------------------------
// ------------- merge with domain = 1 -------------
// set domain to ONE. This should change the result.
let domain = ONE;
@@ -134,6 +129,27 @@ fn hash_padding() {
assert_ne!(r1, r2);
}
#[test]
fn hash_padding_no_extra_permutation_call() {
use crate::hash::rescue::DIGEST_RANGE;
// Implementation
let num_bytes = BINARY_CHUNK_SIZE * RATE_WIDTH;
let mut buffer = vec![0_u8; num_bytes];
*buffer.last_mut().unwrap() = 97;
let r1 = Rpo256::hash(&buffer);
// Expected
let final_chunk = [0_u8, 0, 0, 0, 0, 0, 97, 1];
let mut state = [ZERO; STATE_WIDTH];
// padding when hashing bytes
state[CAPACITY_RANGE.start] = Felt::from(RATE_WIDTH as u8);
*state.last_mut().unwrap() = Felt::new(u64::from_le_bytes(final_chunk));
Rpo256::apply_permutation(&mut state);
assert_eq!(&r1[0..4], &state[DIGEST_RANGE]);
}
#[test]
fn hash_elements_padding() {
let e1 = [Felt::new(rand_value()); 2];
@@ -167,6 +183,24 @@ fn hash_elements() {
assert_eq!(m_result, h_result);
}
#[test]
fn hash_empty() {
let elements: Vec<Felt> = vec![];
let zero_digest = RpoDigest::default();
let h_result = Rpo256::hash_elements(&elements);
assert_eq!(zero_digest, h_result);
}
#[test]
fn hash_empty_bytes() {
let bytes: Vec<u8> = vec![];
let zero_digest = RpoDigest::default();
let h_result = Rpo256::hash(&bytes);
assert_eq!(zero_digest, h_result);
}
#[test]
fn hash_test_vectors() {
let elements = [
@@ -237,46 +271,46 @@ proptest! {
const EXPECTED: [Word; 19] = [
[
Felt::new(1502364727743950833),
Felt::new(5880949717274681448),
Felt::new(162790463902224431),
Felt::new(6901340476773664264),
Felt::new(18126731724905382595),
Felt::new(7388557040857728717),
Felt::new(14290750514634285295),
Felt::new(7852282086160480146),
],
[
Felt::new(7478710183745780580),
Felt::new(3308077307559720969),
Felt::new(3383561985796182409),
Felt::new(17205078494700259815),
Felt::new(10139303045932500183),
Felt::new(2293916558361785533),
Felt::new(15496361415980502047),
Felt::new(17904948502382283940),
],
[
Felt::new(17439912364295172999),
Felt::new(17979156346142712171),
Felt::new(8280795511427637894),
Felt::new(9349844417834368814),
Felt::new(17457546260239634015),
Felt::new(803990662839494686),
Felt::new(10386005777401424878),
Felt::new(18168807883298448638),
],
[
Felt::new(5105868198472766874),
Felt::new(13090564195691924742),
Felt::new(1058904296915798891),
Felt::new(18379501748825152268),
Felt::new(13072499238647455740),
Felt::new(10174350003422057273),
Felt::new(9201651627651151113),
Felt::new(6872461887313298746),
],
[
Felt::new(9133662113608941286),
Felt::new(12096627591905525991),
Felt::new(14963426595993304047),
Felt::new(13290205840019973377),
Felt::new(2903803350580990546),
Felt::new(1838870750730563299),
Felt::new(4258619137315479708),
Felt::new(17334260395129062936),
],
[
Felt::new(3134262397541159485),
Felt::new(10106105871979362399),
Felt::new(138768814855329459),
Felt::new(15044809212457404677),
Felt::new(8571221005243425262),
Felt::new(3016595589318175865),
Felt::new(13933674291329928438),
Felt::new(678640375034313072),
],
[
Felt::new(162696376578462826),
Felt::new(4991300494838863586),
Felt::new(660346084748120605),
Felt::new(13179389528641752698),
Felt::new(16314113978986502310),
Felt::new(14587622368743051587),
Felt::new(2808708361436818462),
Felt::new(10660517522478329440),
],
[
Felt::new(2242391899857912644),
@@ -285,46 +319,46 @@ const EXPECTED: [Word; 19] = [
Felt::new(5046143039268215739),
],
[
Felt::new(9585630502158073976),
Felt::new(1310051013427303477),
Felt::new(7491921222636097758),
Felt::new(9417501558995216762),
Felt::new(5218076004221736204),
Felt::new(17169400568680971304),
Felt::new(8840075572473868990),
Felt::new(12382372614369863623),
],
[
Felt::new(1994394001720334744),
Felt::new(10866209900885216467),
Felt::new(13836092831163031683),
Felt::new(10814636682252756697),
Felt::new(9783834557155203486),
Felt::new(12317263104955018849),
Felt::new(3933748931816109604),
Felt::new(1843043029836917214),
],
[
Felt::new(17486854790732826405),
Felt::new(17376549265955727562),
Felt::new(2371059831956435003),
Felt::new(17585704935858006533),
Felt::new(14498234468286984551),
Felt::new(16837257669834682387),
Felt::new(6664141123711355107),
Felt::new(4590460158294697186),
],
[
Felt::new(11368277489137713825),
Felt::new(3906270146963049287),
Felt::new(10236262408213059745),
Felt::new(78552867005814007),
Felt::new(4661800562479916067),
Felt::new(11794407552792839953),
Felt::new(9037742258721863712),
Felt::new(6287820818064278819),
],
[
Felt::new(17899847381280262181),
Felt::new(14717912805498651446),
Felt::new(10769146203951775298),
Felt::new(2774289833490417856),
Felt::new(7752693085194633729),
Felt::new(7379857372245835536),
Felt::new(9270229380648024178),
Felt::new(10638301488452560378),
],
[
Felt::new(3794717687462954368),
Felt::new(4386865643074822822),
Felt::new(8854162840275334305),
Felt::new(7129983987107225269),
Felt::new(11542686762698783357),
Felt::new(15570714990728449027),
Felt::new(7518801014067819501),
Felt::new(12706437751337583515),
],
[
Felt::new(7244773535611633983),
Felt::new(19359923075859320),
Felt::new(10898655967774994333),
Felt::new(9319339563065736480),
Felt::new(9553923701032839042),
Felt::new(7281190920209838818),
Felt::new(2488477917448393955),
Felt::new(5088955350303368837),
],
[
Felt::new(4935426252518736883),
@@ -333,21 +367,21 @@ const EXPECTED: [Word; 19] = [
Felt::new(18159875708229758073),
],
[
Felt::new(14871230873837295931),
Felt::new(11225255908868362971),
Felt::new(18100987641405432308),
Felt::new(1559244340089644233),
Felt::new(12795429638314178838),
Felt::new(14360248269767567855),
Felt::new(3819563852436765058),
Felt::new(10859123583999067291),
],
[
Felt::new(8348203744950016968),
Felt::new(4041411241960726733),
Felt::new(17584743399305468057),
Felt::new(16836952610803537051),
Felt::new(2695742617679420093),
Felt::new(9151515850666059759),
Felt::new(15855828029180595485),
Felt::new(17190029785471463210),
],
[
Felt::new(16139797453633030050),
Felt::new(1090233424040889412),
Felt::new(10770255347785669036),
Felt::new(16982398877290254028),
Felt::new(13205273108219124830),
Felt::new(2524898486192849221),
Felt::new(14618764355375283547),
Felt::new(10615614265042186874),
],
];


@@ -0,0 +1,634 @@
use alloc::string::String;
use core::{cmp::Ordering, fmt::Display, ops::Deref, slice};
use thiserror::Error;
use super::{Digest, Felt, StarkField, DIGEST_BYTES, DIGEST_SIZE, ZERO};
use crate::{
rand::Randomizable,
utils::{
bytes_to_hex_string, hex_to_bytes, ByteReader, ByteWriter, Deserializable,
DeserializationError, HexParseError, Serializable,
},
};
// DIGEST TRAIT IMPLEMENTATIONS
// ================================================================================================
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))]
pub struct RpxDigest([Felt; DIGEST_SIZE]);
impl RpxDigest {
/// The serialized size of the digest in bytes.
pub const SERIALIZED_SIZE: usize = DIGEST_BYTES;
pub const fn new(value: [Felt; DIGEST_SIZE]) -> Self {
Self(value)
}
pub fn as_elements(&self) -> &[Felt] {
self.as_ref()
}
pub fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
<Self as Digest>::as_bytes(self)
}
pub fn digests_as_elements_iter<'a, I>(digests: I) -> impl Iterator<Item = &'a Felt>
where
I: Iterator<Item = &'a Self>,
{
digests.flat_map(|d| d.0.iter())
}
pub fn digests_as_elements(digests: &[Self]) -> &[Felt] {
let p = digests.as_ptr();
let len = digests.len() * DIGEST_SIZE;
unsafe { slice::from_raw_parts(p as *const Felt, len) }
}
/// Returns hexadecimal representation of this digest prefixed with `0x`.
pub fn to_hex(&self) -> String {
bytes_to_hex_string(self.as_bytes())
}
}
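The `to_hex()` helper above delegates to the crate's `bytes_to_hex_string`. As a plain-Rust sketch of what that encoding produces (an assumption about its format based on the `0x`-prefixed strings used throughout this file):

```rust
/// Hex string with a `0x` prefix from raw digest bytes, mirroring `to_hex()`.
fn to_hex(bytes: &[u8]) -> String {
    let mut s = String::with_capacity(2 + bytes.len() * 2);
    s.push_str("0x");
    for b in bytes {
        s.push_str(&format!("{b:02x}"));
    }
    s
}

fn main() {
    assert_eq!(to_hex(&[0x01, 0xab]), "0x01ab");
    assert_eq!(to_hex(&[0u8; 4]), "0x00000000");
}
```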
impl Digest for RpxDigest {
fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
let mut result = [0; DIGEST_BYTES];
result[..8].copy_from_slice(&self.0[0].as_int().to_le_bytes());
result[8..16].copy_from_slice(&self.0[1].as_int().to_le_bytes());
result[16..24].copy_from_slice(&self.0[2].as_int().to_le_bytes());
result[24..].copy_from_slice(&self.0[3].as_int().to_le_bytes());
result
}
}
impl Deref for RpxDigest {
type Target = [Felt; DIGEST_SIZE];
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl Ord for RpxDigest {
fn cmp(&self, other: &Self) -> Ordering {
// compare the inner u64 representations of both digests.
//
// this iterates over the elements and returns the first comparison that is not `Equal`;
// otherwise, the ordering is `Equal`.
//
// the endianness is irrelevant here because, this being a cryptographically secure hash
// computation, the digest shouldn't preserve any ordering property of its input.
//
// finally, we use `Felt::inner` instead of `Felt::as_int` to avoid performing a Montgomery
// reduction for every limb. that is safe because every inner element of the digest is
// guaranteed to be in its canonical form (that is, `x in [0,p)`).
self.0.iter().map(Felt::inner).zip(other.0.iter().map(Felt::inner)).fold(
Ordering::Equal,
|ord, (a, b)| match ord {
Ordering::Equal => a.cmp(&b),
_ => ord,
},
)
}
}
impl PartialOrd for RpxDigest {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
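The fold-based comparison in the `Ord` impl above can be demonstrated on bare `u64` limbs. A minimal sketch (the real impl compares `Felt::inner` values, not raw integers):

```rust
use core::cmp::Ordering;

/// Lexicographic comparison of four u64 limbs via `fold`, mirroring the
/// `Ord` impl above: keep the first non-`Equal` ordering encountered.
fn cmp_limbs(a: &[u64; 4], b: &[u64; 4]) -> Ordering {
    a.iter().zip(b.iter()).fold(Ordering::Equal, |ord, (x, y)| match ord {
        Ordering::Equal => x.cmp(y),
        _ => ord,
    })
}

fn main() {
    assert_eq!(cmp_limbs(&[1, 2, 3, 4], &[1, 2, 3, 4]), Ordering::Equal);
    assert_eq!(cmp_limbs(&[1, 2, 3, 4], &[1, 3, 0, 0]), Ordering::Less);
    assert_eq!(cmp_limbs(&[2, 0, 0, 0], &[1, 9, 9, 9]), Ordering::Greater);
}
```

Once a non-`Equal` ordering is found, later limbs cannot change the result, which is exactly lexicographic order.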
impl Display for RpxDigest {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
let encoded: String = self.into();
write!(f, "{}", encoded)?;
Ok(())
}
}
impl Randomizable for RpxDigest {
const VALUE_SIZE: usize = DIGEST_BYTES;
fn from_random_bytes(bytes: &[u8]) -> Option<Self> {
let bytes_array: Option<[u8; 32]> = bytes.try_into().ok();
if let Some(bytes_array) = bytes_array {
Self::try_from(bytes_array).ok()
} else {
None
}
}
}
// CONVERSIONS: FROM RPX DIGEST
// ================================================================================================
#[derive(Debug, Error)]
pub enum RpxDigestError {
#[error("failed to convert digest field element to {0}")]
TypeConversion(&'static str),
#[error("failed to convert to field element: {0}")]
InvalidFieldElement(String),
}
impl TryFrom<&RpxDigest> for [bool; DIGEST_SIZE] {
type Error = RpxDigestError;
fn try_from(value: &RpxDigest) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<RpxDigest> for [bool; DIGEST_SIZE] {
type Error = RpxDigestError;
fn try_from(value: RpxDigest) -> Result<Self, Self::Error> {
fn to_bool(v: u64) -> Option<bool> {
if v <= 1 {
Some(v == 1)
} else {
None
}
}
Ok([
to_bool(value.0[0].as_int()).ok_or(RpxDigestError::TypeConversion("bool"))?,
to_bool(value.0[1].as_int()).ok_or(RpxDigestError::TypeConversion("bool"))?,
to_bool(value.0[2].as_int()).ok_or(RpxDigestError::TypeConversion("bool"))?,
to_bool(value.0[3].as_int()).ok_or(RpxDigestError::TypeConversion("bool"))?,
])
}
}
impl TryFrom<&RpxDigest> for [u8; DIGEST_SIZE] {
type Error = RpxDigestError;
fn try_from(value: &RpxDigest) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<RpxDigest> for [u8; DIGEST_SIZE] {
type Error = RpxDigestError;
fn try_from(value: RpxDigest) -> Result<Self, Self::Error> {
Ok([
value.0[0]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u8"))?,
value.0[1]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u8"))?,
value.0[2]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u8"))?,
value.0[3]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u8"))?,
])
}
}
impl TryFrom<&RpxDigest> for [u16; DIGEST_SIZE] {
type Error = RpxDigestError;
fn try_from(value: &RpxDigest) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<RpxDigest> for [u16; DIGEST_SIZE] {
type Error = RpxDigestError;
fn try_from(value: RpxDigest) -> Result<Self, Self::Error> {
Ok([
value.0[0]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u16"))?,
value.0[1]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u16"))?,
value.0[2]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u16"))?,
value.0[3]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u16"))?,
])
}
}
impl TryFrom<&RpxDigest> for [u32; DIGEST_SIZE] {
type Error = RpxDigestError;
fn try_from(value: &RpxDigest) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<RpxDigest> for [u32; DIGEST_SIZE] {
type Error = RpxDigestError;
fn try_from(value: RpxDigest) -> Result<Self, Self::Error> {
Ok([
value.0[0]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u32"))?,
value.0[1]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u32"))?,
value.0[2]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u32"))?,
value.0[3]
.as_int()
.try_into()
.map_err(|_| RpxDigestError::TypeConversion("u32"))?,
])
}
}
impl From<&RpxDigest> for [u64; DIGEST_SIZE] {
fn from(value: &RpxDigest) -> Self {
(*value).into()
}
}
impl From<RpxDigest> for [u64; DIGEST_SIZE] {
fn from(value: RpxDigest) -> Self {
[
value.0[0].as_int(),
value.0[1].as_int(),
value.0[2].as_int(),
value.0[3].as_int(),
]
}
}
impl From<&RpxDigest> for [Felt; DIGEST_SIZE] {
fn from(value: &RpxDigest) -> Self {
value.0
}
}
impl From<RpxDigest> for [Felt; DIGEST_SIZE] {
fn from(value: RpxDigest) -> Self {
value.0
}
}
impl From<&RpxDigest> for [u8; DIGEST_BYTES] {
fn from(value: &RpxDigest) -> Self {
value.as_bytes()
}
}
impl From<RpxDigest> for [u8; DIGEST_BYTES] {
fn from(value: RpxDigest) -> Self {
value.as_bytes()
}
}
impl From<&RpxDigest> for String {
/// The returned string starts with `0x`.
fn from(value: &RpxDigest) -> Self {
(*value).into()
}
}
impl From<RpxDigest> for String {
/// The returned string starts with `0x`.
fn from(value: RpxDigest) -> Self {
value.to_hex()
}
}
// CONVERSIONS: TO RPX DIGEST
// ================================================================================================
impl From<&[bool; DIGEST_SIZE]> for RpxDigest {
fn from(value: &[bool; DIGEST_SIZE]) -> Self {
(*value).into()
}
}
impl From<[bool; DIGEST_SIZE]> for RpxDigest {
fn from(value: [bool; DIGEST_SIZE]) -> Self {
[value[0] as u32, value[1] as u32, value[2] as u32, value[3] as u32].into()
}
}
impl From<&[u8; DIGEST_SIZE]> for RpxDigest {
fn from(value: &[u8; DIGEST_SIZE]) -> Self {
(*value).into()
}
}
impl From<[u8; DIGEST_SIZE]> for RpxDigest {
fn from(value: [u8; DIGEST_SIZE]) -> Self {
Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
}
}
impl From<&[u16; DIGEST_SIZE]> for RpxDigest {
fn from(value: &[u16; DIGEST_SIZE]) -> Self {
(*value).into()
}
}
impl From<[u16; DIGEST_SIZE]> for RpxDigest {
fn from(value: [u16; DIGEST_SIZE]) -> Self {
Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
}
}
impl From<&[u32; DIGEST_SIZE]> for RpxDigest {
fn from(value: &[u32; DIGEST_SIZE]) -> Self {
(*value).into()
}
}
impl From<[u32; DIGEST_SIZE]> for RpxDigest {
fn from(value: [u32; DIGEST_SIZE]) -> Self {
Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])
}
}
impl TryFrom<&[u64; DIGEST_SIZE]> for RpxDigest {
type Error = RpxDigestError;
fn try_from(value: &[u64; DIGEST_SIZE]) -> Result<Self, RpxDigestError> {
(*value).try_into()
}
}
impl TryFrom<[u64; DIGEST_SIZE]> for RpxDigest {
type Error = RpxDigestError;
fn try_from(value: [u64; DIGEST_SIZE]) -> Result<Self, RpxDigestError> {
Ok(Self([
value[0].try_into().map_err(RpxDigestError::InvalidFieldElement)?,
value[1].try_into().map_err(RpxDigestError::InvalidFieldElement)?,
value[2].try_into().map_err(RpxDigestError::InvalidFieldElement)?,
value[3].try_into().map_err(RpxDigestError::InvalidFieldElement)?,
]))
}
}
impl From<&[Felt; DIGEST_SIZE]> for RpxDigest {
fn from(value: &[Felt; DIGEST_SIZE]) -> Self {
Self(*value)
}
}
impl From<[Felt; DIGEST_SIZE]> for RpxDigest {
fn from(value: [Felt; DIGEST_SIZE]) -> Self {
Self(value)
}
}
impl TryFrom<&[u8; DIGEST_BYTES]> for RpxDigest {
type Error = HexParseError;
fn try_from(value: &[u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<[u8; DIGEST_BYTES]> for RpxDigest {
type Error = HexParseError;
fn try_from(value: [u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
// Note: since the input length is known, the conversions from slice to array must succeed,
// so the `unwrap`s below are safe
let a = u64::from_le_bytes(value[0..8].try_into().unwrap());
let b = u64::from_le_bytes(value[8..16].try_into().unwrap());
let c = u64::from_le_bytes(value[16..24].try_into().unwrap());
let d = u64::from_le_bytes(value[24..32].try_into().unwrap());
if [a, b, c, d].iter().any(|v| *v >= Felt::MODULUS) {
return Err(HexParseError::OutOfRange);
}
Ok(RpxDigest([Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)]))
}
}
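The byte-to-limb decoding with a range check, as performed by the `TryFrom<[u8; DIGEST_BYTES]>` impl above, can be sketched with plain integers. Assumes the Goldilocks prime used by `Felt`; the hypothetical error string stands in for `HexParseError::OutOfRange`.

```rust
// assumption: the field modulus used by `Felt` (Goldilocks prime)
const MODULUS: u64 = 0xFFFF_FFFF_0000_0001;

/// Decodes four little-endian u64 limbs from 32 bytes and range-checks each
/// against the field modulus, mirroring the `TryFrom<[u8; 32]>` impl above.
fn decode_digest(bytes: &[u8; 32]) -> Result<[u64; 4], &'static str> {
    let mut limbs = [0u64; 4];
    for (i, chunk) in bytes.chunks_exact(8).enumerate() {
        let v = u64::from_le_bytes(chunk.try_into().unwrap());
        if v >= MODULUS {
            return Err("limb out of range"); // stands in for HexParseError::OutOfRange
        }
        limbs[i] = v;
    }
    Ok(limbs)
}

fn main() {
    let mut ok = [0u8; 32];
    ok[0] = 7; // first limb = 7, all others zero
    assert_eq!(decode_digest(&ok).unwrap(), [7, 0, 0, 0]);
    let bad = [0xFF_u8; 32]; // every limb is u64::MAX >= MODULUS
    assert!(decode_digest(&bad).is_err());
}
```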
impl TryFrom<&[u8]> for RpxDigest {
type Error = HexParseError;
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
(*value).try_into()
}
}
impl TryFrom<&str> for RpxDigest {
type Error = HexParseError;
/// Expects the string to start with `0x`.
fn try_from(value: &str) -> Result<Self, Self::Error> {
hex_to_bytes::<DIGEST_BYTES>(value).and_then(RpxDigest::try_from)
}
}
impl TryFrom<&String> for RpxDigest {
type Error = HexParseError;
/// Expects the string to start with `0x`.
fn try_from(value: &String) -> Result<Self, Self::Error> {
value.as_str().try_into()
}
}
impl TryFrom<String> for RpxDigest {
type Error = HexParseError;
/// Expects the string to start with `0x`.
fn try_from(value: String) -> Result<Self, Self::Error> {
value.as_str().try_into()
}
}
// SERIALIZATION / DESERIALIZATION
// ================================================================================================
impl Serializable for RpxDigest {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write_bytes(&self.as_bytes());
}
fn get_size_hint(&self) -> usize {
Self::SERIALIZED_SIZE
}
}
impl Deserializable for RpxDigest {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let mut inner: [Felt; DIGEST_SIZE] = [ZERO; DIGEST_SIZE];
for inner in inner.iter_mut() {
let e = source.read_u64()?;
if e >= Felt::MODULUS {
return Err(DeserializationError::InvalidValue(String::from(
"Value not in the appropriate range",
)));
}
*inner = Felt::new(e);
}
Ok(Self(inner))
}
}
// ITERATORS
// ================================================================================================
impl IntoIterator for RpxDigest {
type Item = Felt;
type IntoIter = <[Felt; 4] as IntoIterator>::IntoIter;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
use alloc::string::String;
use rand_utils::rand_value;
use super::{Deserializable, Felt, RpxDigest, Serializable, DIGEST_BYTES, DIGEST_SIZE};
use crate::utils::SliceReader;
#[test]
fn digest_serialization() {
let e1 = Felt::new(rand_value());
let e2 = Felt::new(rand_value());
let e3 = Felt::new(rand_value());
let e4 = Felt::new(rand_value());
let d1 = RpxDigest([e1, e2, e3, e4]);
let mut bytes = vec![];
d1.write_into(&mut bytes);
assert_eq!(DIGEST_BYTES, bytes.len());
assert_eq!(bytes.len(), d1.get_size_hint());
let mut reader = SliceReader::new(&bytes);
let d2 = RpxDigest::read_from(&mut reader).unwrap();
assert_eq!(d1, d2);
}
#[test]
fn digest_encoding() {
let digest = RpxDigest([
Felt::new(rand_value()),
Felt::new(rand_value()),
Felt::new(rand_value()),
Felt::new(rand_value()),
]);
let string: String = digest.into();
let round_trip: RpxDigest = string.try_into().expect("decoding failed");
assert_eq!(digest, round_trip);
}
#[test]
fn test_conversions() {
let digest = RpxDigest([
Felt::new(rand_value()),
Felt::new(rand_value()),
Felt::new(rand_value()),
Felt::new(rand_value()),
]);
// BY VALUE
// ----------------------------------------------------------------------------------------
let v: [bool; DIGEST_SIZE] = [true, false, true, true];
let v2: RpxDigest = v.into();
assert_eq!(v, <[bool; DIGEST_SIZE]>::try_from(v2).unwrap());
let v: [u8; DIGEST_SIZE] = [0_u8, 1_u8, 2_u8, 3_u8];
let v2: RpxDigest = v.into();
assert_eq!(v, <[u8; DIGEST_SIZE]>::try_from(v2).unwrap());
let v: [u16; DIGEST_SIZE] = [0_u16, 1_u16, 2_u16, 3_u16];
let v2: RpxDigest = v.into();
assert_eq!(v, <[u16; DIGEST_SIZE]>::try_from(v2).unwrap());
let v: [u32; DIGEST_SIZE] = [0_u32, 1_u32, 2_u32, 3_u32];
let v2: RpxDigest = v.into();
assert_eq!(v, <[u32; DIGEST_SIZE]>::try_from(v2).unwrap());
let v: [u64; DIGEST_SIZE] = digest.into();
let v2: RpxDigest = v.try_into().unwrap();
assert_eq!(digest, v2);
let v: [Felt; DIGEST_SIZE] = digest.into();
let v2: RpxDigest = v.into();
assert_eq!(digest, v2);
let v: [u8; DIGEST_BYTES] = digest.into();
let v2: RpxDigest = v.try_into().unwrap();
assert_eq!(digest, v2);
let v: String = digest.into();
let v2: RpxDigest = v.try_into().unwrap();
assert_eq!(digest, v2);
// BY REF
// ----------------------------------------------------------------------------------------
let v: [bool; DIGEST_SIZE] = [true, false, true, true];
let v2: RpxDigest = (&v).into();
assert_eq!(v, <[bool; DIGEST_SIZE]>::try_from(&v2).unwrap());
let v: [u8; DIGEST_SIZE] = [0_u8, 1_u8, 2_u8, 3_u8];
let v2: RpxDigest = (&v).into();
assert_eq!(v, <[u8; DIGEST_SIZE]>::try_from(&v2).unwrap());
let v: [u16; DIGEST_SIZE] = [0_u16, 1_u16, 2_u16, 3_u16];
let v2: RpxDigest = (&v).into();
assert_eq!(v, <[u16; DIGEST_SIZE]>::try_from(&v2).unwrap());
let v: [u32; DIGEST_SIZE] = [0_u32, 1_u32, 2_u32, 3_u32];
let v2: RpxDigest = (&v).into();
assert_eq!(v, <[u32; DIGEST_SIZE]>::try_from(&v2).unwrap());
let v: [u64; DIGEST_SIZE] = (&digest).into();
let v2: RpxDigest = (&v).try_into().unwrap();
assert_eq!(digest, v2);
let v: [Felt; DIGEST_SIZE] = (&digest).into();
let v2: RpxDigest = (&v).into();
assert_eq!(digest, v2);
let v: [u8; DIGEST_BYTES] = (&digest).into();
let v2: RpxDigest = (&v).try_into().unwrap();
assert_eq!(digest, v2);
let v: String = (&digest).into();
let v2: RpxDigest = (&v).try_into().unwrap();
assert_eq!(digest, v2);
}
}
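The serialization round-trip exercised above relies on the digest's fixed byte layout: four field elements, each written as 8 little-endian bytes (DIGEST_BYTES = 32 total), with deserialization rejecting any limb at or above the field modulus. Below is a minimal standalone sketch of that layout, using raw `u64` limbs in place of `Felt`; the helper names are illustrative and not part of the crate:

```rust
use core::convert::TryInto;

const DIGEST_SIZE: usize = 4;
const DIGEST_BYTES: usize = 32;
const MODULUS: u64 = 0xffff_ffff_0000_0001; // 2^64 - 2^32 + 1

// Serialize four limbs as little-endian bytes, mirroring the layout used by
// the digest's `as_bytes`.
fn digest_to_bytes(limbs: &[u64; DIGEST_SIZE]) -> [u8; DIGEST_BYTES] {
    let mut out = [0u8; DIGEST_BYTES];
    for (i, limb) in limbs.iter().enumerate() {
        out[i * 8..(i + 1) * 8].copy_from_slice(&limb.to_le_bytes());
    }
    out
}

// Deserialize, rejecting any limb outside the canonical range [0, MODULUS).
fn bytes_to_digest(bytes: &[u8; DIGEST_BYTES]) -> Result<[u64; DIGEST_SIZE], &'static str> {
    let mut out = [0u64; DIGEST_SIZE];
    for i in 0..DIGEST_SIZE {
        let limb = u64::from_le_bytes(bytes[i * 8..(i + 1) * 8].try_into().unwrap());
        if limb >= MODULUS {
            return Err("limb is not a canonical field element");
        }
        out[i] = limb;
    }
    Ok(out)
}

fn main() {
    let d = [1u64, 2, 3, 4];
    assert_eq!(bytes_to_digest(&digest_to_bytes(&d)), Ok(d));
    // a limb equal to the modulus must be rejected on the way back in
    assert!(bytes_to_digest(&digest_to_bytes(&[MODULUS, 0, 0, 0])).is_err());
}
```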

src/hash/rescue/rpx/mod.rs

@@ -0,0 +1,385 @@
use core::ops::Range;
use super::{
add_constants, add_constants_and_apply_inv_sbox, add_constants_and_apply_sbox, apply_inv_sbox,
apply_mds, apply_sbox, CubeExtension, Digest, ElementHasher, Felt, FieldElement, Hasher,
StarkField, ARK1, ARK2, BINARY_CHUNK_SIZE, CAPACITY_RANGE, DIGEST_BYTES, DIGEST_RANGE,
DIGEST_SIZE, INPUT1_RANGE, INPUT2_RANGE, MDS, NUM_ROUNDS, RATE_RANGE, RATE_WIDTH, STATE_WIDTH,
ZERO,
};
mod digest;
pub use digest::{RpxDigest, RpxDigestError};
#[cfg(test)]
mod tests;
pub type CubicExtElement = CubeExtension<Felt>;
// HASHER IMPLEMENTATION
// ================================================================================================
/// Implementation of the Rescue Prime eXtension hash function with 256-bit output.
///
/// The hash function is based on the XHash12 construction in [specifications](https://eprint.iacr.org/2023/1045)
///
/// The parameters used to instantiate the function are:
/// * Field: 64-bit prime field with modulus 2^64 - 2^32 + 1.
/// * State width: 12 field elements.
/// * Capacity size: 4 field elements.
/// * S-Box degree: 7.
/// * Rounds: There are 3 different types of rounds:
/// - (FB): `apply_mds` → `add_constants` → `apply_sbox` → `apply_mds` → `add_constants` →
/// `apply_inv_sbox`.
/// - (E): `add_constants` → `ext_sbox` (which is raising to power 7 in the degree 3 extension
/// field).
/// - (M): `apply_mds` → `add_constants`.
/// * Permutation: (FB) (E) (FB) (E) (FB) (E) (M).
///
/// The above parameters target a 128-bit security level. The digest consists of four field elements
/// and it can be serialized into 32 bytes (256 bits).
///
/// ## Hash output consistency
/// Functions [hash_elements()](Rpx256::hash_elements), [merge()](Rpx256::merge), and
/// [merge_with_int()](Rpx256::merge_with_int) are internally consistent. That is, computing
/// a hash for the same set of elements using these functions will always produce the same
/// result. For example, merging two digests using [merge()](Rpx256::merge) will produce the
/// same result as hashing 8 elements which make up these digests using
/// [hash_elements()](Rpx256::hash_elements) function.
///
/// However, [hash()](Rpx256::hash) function is not consistent with functions mentioned above.
/// For example, if we take two field elements, serialize them to bytes and hash them using
/// [hash()](Rpx256::hash), the result will differ from the result obtained by hashing these
/// elements directly using [hash_elements()](Rpx256::hash_elements) function. The reason for
/// this difference is that [hash()](Rpx256::hash) function needs to be able to handle
/// arbitrary binary strings, which may or may not encode valid field elements - and thus,
/// deserialization procedure used by this function is different from the procedure used to
/// deserialize valid field elements.
///
/// Thus, if the underlying data consists of valid field elements, it might make more sense
/// to deserialize them into field elements and then hash them using
/// [hash_elements()](Rpx256::hash_elements) function rather than hashing the serialized bytes
/// using [hash()](Rpx256::hash) function.
///
/// ## Domain separation
/// [merge_in_domain()](Rpx256::merge_in_domain) hashes two digests into one digest with some domain
/// identifier and the current implementation sets the second capacity element to the value of
/// this domain identifier. Using a similar argument to the one formulated for domain separation
/// in Appendix C of the [specifications](https://eprint.iacr.org/2023/1045), one sees that doing
/// so degrades only pre-image resistance, from its initial bound of c.log_2(p), by as much as
/// the log_2 of the size of the domain identifier space. Since pre-image resistance becomes
/// the bottleneck for the security bound of the sponge in overwrite-mode only when it is
/// lower than 2^128, we see that the target 128-bit security level is maintained as long as
/// the size of the domain identifier space, including for padding, is less than 2^128.
///
/// ## Hashing of empty input
/// The current implementation hashes empty input to the zero digest [0, 0, 0, 0]. This has
/// the benefit of requiring no calls to the RPX permutation when hashing empty input.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Rpx256();
impl Hasher for Rpx256 {
/// Rpx256 collision resistance is 128 bits.
const COLLISION_RESISTANCE: u32 = 128;
type Digest = RpxDigest;
fn hash(bytes: &[u8]) -> Self::Digest {
// initialize the state with zeroes
let mut state = [ZERO; STATE_WIDTH];
// determine the number of field elements needed to encode `bytes` when each field element
// represents at most 7 bytes.
let num_field_elem = bytes.len().div_ceil(BINARY_CHUNK_SIZE);
// set the first capacity element to `RATE_WIDTH + (num_field_elem % RATE_WIDTH)`. We do
// this to achieve:
// 1. Domain separating hashing of `[u8]` from hashing of `[Felt]`.
// 2. Avoiding collisions at the `[Felt]` representation of the encoded bytes.
state[CAPACITY_RANGE.start] =
Felt::from((RATE_WIDTH + (num_field_elem % RATE_WIDTH)) as u8);
// initialize a buffer to receive the little-endian elements.
let mut buf = [0_u8; 8];
// iterate the chunks of bytes, creating a field element from each chunk and copying it
// into the state.
//
// every time the rate range is filled, a permutation is performed. if the final value of
// `rate_pos` is not zero, the chunks did not fill the rate range exactly, and an
// additional permutation must be performed.
let mut current_chunk_idx = 0_usize;
// handle the case of an empty `bytes`
let last_chunk_idx = if num_field_elem == 0 {
current_chunk_idx
} else {
num_field_elem - 1
};
let rate_pos = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |rate_pos, chunk| {
// copy the chunk into the buffer
if current_chunk_idx != last_chunk_idx {
buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk);
} else {
// on the last iteration, we pad `buf` with a 1 followed by as many 0's as are
// needed to fill it
buf.fill(0);
buf[..chunk.len()].copy_from_slice(chunk);
buf[chunk.len()] = 1;
}
current_chunk_idx += 1;
// set the current rate element to the input. since we take at most 7 bytes, we are
// guaranteed that the inputs data will fit into a single field element.
state[RATE_RANGE.start + rate_pos] = Felt::new(u64::from_le_bytes(buf));
// proceed filling the range. if it's full, then we apply a permutation and reset the
// counter to the beginning of the range.
if rate_pos == RATE_WIDTH - 1 {
Self::apply_permutation(&mut state);
0
} else {
rate_pos + 1
}
});
// if we absorbed some elements but didn't apply a permutation to them (would happen when
// the number of elements is not a multiple of RATE_WIDTH), apply the RPX permutation. we
// don't need to apply any extra padding because the first capacity element contains a
// flag indicating the number of field elements constituting the last block when the latter
// is not divisible by `RATE_WIDTH`.
if rate_pos != 0 {
state[RATE_RANGE.start + rate_pos..RATE_RANGE.end].fill(ZERO);
Self::apply_permutation(&mut state);
}
// return the first 4 elements of the rate as hash result.
RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}
fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
// initialize the state by copying the digest elements into the rate portion of the state
// (8 total elements), and set the capacity elements to 0.
let mut state = [ZERO; STATE_WIDTH];
let it = Self::Digest::digests_as_elements_iter(values.iter());
for (i, v) in it.enumerate() {
state[RATE_RANGE.start + i] = *v;
}
// apply the RPX permutation and return the first four elements of the state
Self::apply_permutation(&mut state);
RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}
fn merge_many(values: &[Self::Digest]) -> Self::Digest {
Self::hash_elements(Self::Digest::digests_as_elements(values))
}
fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
// initialize the state as follows:
// - seed is copied into the first 4 elements of the rate portion of the state.
// - if the value fits into a single field element, copy it into the fifth rate element and
// set the first capacity element to 5.
// - if the value doesn't fit into a single field element, split it into two field elements,
// copy them into rate elements 5 and 6 and set the first capacity element to 6.
let mut state = [ZERO; STATE_WIDTH];
state[INPUT1_RANGE].copy_from_slice(seed.as_elements());
state[INPUT2_RANGE.start] = Felt::new(value);
if value < Felt::MODULUS {
state[CAPACITY_RANGE.start] = Felt::from(5_u8);
} else {
state[INPUT2_RANGE.start + 1] = Felt::new(value / Felt::MODULUS);
state[CAPACITY_RANGE.start] = Felt::from(6_u8);
}
// apply the RPX permutation and return the first four elements of the rate
Self::apply_permutation(&mut state);
RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}
}
impl ElementHasher for Rpx256 {
type BaseField = Felt;
fn hash_elements<E: FieldElement<BaseField = Self::BaseField>>(elements: &[E]) -> Self::Digest {
// convert the elements into a list of base field elements
let elements = E::slice_as_base_elements(elements);
// initialize state to all zeros, except for the first element of the capacity part, which
// is set to `elements.len() % RATE_WIDTH`.
let mut state = [ZERO; STATE_WIDTH];
state[CAPACITY_RANGE.start] = Self::BaseField::from((elements.len() % RATE_WIDTH) as u8);
// absorb elements into the state one by one until the rate portion of the state is filled
// up; then apply the Rescue permutation and start absorbing again; repeat until all
// elements have been absorbed
let mut i = 0;
for &element in elements.iter() {
state[RATE_RANGE.start + i] = element;
i += 1;
if i % RATE_WIDTH == 0 {
Self::apply_permutation(&mut state);
i = 0;
}
}
// if we absorbed some elements but didn't apply a permutation to them (would happen when
// the number of elements is not a multiple of RATE_WIDTH), apply the RPX permutation after
// padding by as many 0 as necessary to make the input length a multiple of the RATE_WIDTH.
if i > 0 {
while i != RATE_WIDTH {
state[RATE_RANGE.start + i] = ZERO;
i += 1;
}
Self::apply_permutation(&mut state);
}
// return the first 4 elements of the state as hash result
RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}
}
// HASH FUNCTION IMPLEMENTATION
// ================================================================================================
impl Rpx256 {
// CONSTANTS
// --------------------------------------------------------------------------------------------
/// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate and
/// the remaining 4 elements are reserved for capacity.
pub const STATE_WIDTH: usize = STATE_WIDTH;
/// The rate portion of the state is located in elements 4 through 11 (inclusive).
pub const RATE_RANGE: Range<usize> = RATE_RANGE;
/// The capacity portion of the state is located in elements 0, 1, 2, and 3.
pub const CAPACITY_RANGE: Range<usize> = CAPACITY_RANGE;
/// The output of the hash function can be read from state elements 4, 5, 6, and 7.
pub const DIGEST_RANGE: Range<usize> = DIGEST_RANGE;
/// MDS matrix used for computing the linear layer in the (FB) and (E) rounds.
pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = MDS;
/// Round constants added to the hasher state in the first half of the round.
pub const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK1;
/// Round constants added to the hasher state in the second half of the round.
pub const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK2;
// TRAIT PASS-THROUGH FUNCTIONS
// --------------------------------------------------------------------------------------------
/// Returns a hash of the provided sequence of bytes.
#[inline(always)]
pub fn hash(bytes: &[u8]) -> RpxDigest {
<Self as Hasher>::hash(bytes)
}
/// Returns a hash of two digests. This method is intended for use in construction of
/// Merkle trees and verification of Merkle paths.
#[inline(always)]
pub fn merge(values: &[RpxDigest; 2]) -> RpxDigest {
<Self as Hasher>::merge(values)
}
/// Returns a hash of the provided field elements.
#[inline(always)]
pub fn hash_elements<E: FieldElement<BaseField = Felt>>(elements: &[E]) -> RpxDigest {
<Self as ElementHasher>::hash_elements(elements)
}
// DOMAIN IDENTIFIER
// --------------------------------------------------------------------------------------------
/// Returns a hash of two digests and a domain identifier.
pub fn merge_in_domain(values: &[RpxDigest; 2], domain: Felt) -> RpxDigest {
// initialize the state by copying the digest elements into the rate portion of the state
// (8 total elements), and set the capacity elements to 0.
let mut state = [ZERO; STATE_WIDTH];
let it = RpxDigest::digests_as_elements_iter(values.iter());
for (i, v) in it.enumerate() {
state[RATE_RANGE.start + i] = *v;
}
// set the second capacity element to the domain value. The first capacity element is used
// for padding purposes.
state[CAPACITY_RANGE.start + 1] = domain;
// apply the RPX permutation and return the first four elements of the state
Self::apply_permutation(&mut state);
RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}
// RPX PERMUTATION
// --------------------------------------------------------------------------------------------
/// Applies RPX permutation to the provided state.
#[inline(always)]
pub fn apply_permutation(state: &mut [Felt; STATE_WIDTH]) {
Self::apply_fb_round(state, 0);
Self::apply_ext_round(state, 1);
Self::apply_fb_round(state, 2);
Self::apply_ext_round(state, 3);
Self::apply_fb_round(state, 4);
Self::apply_ext_round(state, 5);
Self::apply_final_round(state, 6);
}
// RPX PERMUTATION ROUND FUNCTIONS
// --------------------------------------------------------------------------------------------
/// (FB) round function.
#[inline(always)]
pub fn apply_fb_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
apply_mds(state);
if !add_constants_and_apply_sbox(state, &ARK1[round]) {
add_constants(state, &ARK1[round]);
apply_sbox(state);
}
apply_mds(state);
if !add_constants_and_apply_inv_sbox(state, &ARK2[round]) {
add_constants(state, &ARK2[round]);
apply_inv_sbox(state);
}
}
/// (E) round function.
#[inline(always)]
pub fn apply_ext_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
// add constants
add_constants(state, &ARK1[round]);
// decompose the state into 4 elements in the cubic extension field and apply the power 7
// map to each of the elements
let [s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] = *state;
let ext0 = Self::exp7(CubicExtElement::new(s0, s1, s2));
let ext1 = Self::exp7(CubicExtElement::new(s3, s4, s5));
let ext2 = Self::exp7(CubicExtElement::new(s6, s7, s8));
let ext3 = Self::exp7(CubicExtElement::new(s9, s10, s11));
// decompose the state back into 12 base field elements
let arr_ext = [ext0, ext1, ext2, ext3];
*state = CubicExtElement::slice_as_base_elements(&arr_ext)
.try_into()
.expect("shouldn't fail");
}
/// (M) round function.
#[inline(always)]
pub fn apply_final_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
apply_mds(state);
add_constants(state, &ARK1[round]);
}
/// Computes an exponentiation to the power 7 in cubic extension field.
#[inline(always)]
pub fn exp7(x: CubeExtension<Felt>) -> CubeExtension<Felt> {
let x2 = x.square();
let x4 = x2.square();
let x3 = x2 * x;
x3 * x4
}
}
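The addition chain in `exp7` computes x^7 as x^3 · x^4 from two squarings and two multiplications. The same chain can be sanity-checked with standalone integer arithmetic in the base field (the same Goldilocks prime; the real `exp7` operates on cubic-extension elements, and these helper names are illustrative):

```rust
const P: u128 = 0xffff_ffff_0000_0001; // 2^64 - 2^32 + 1, the base field modulus

fn mul(a: u128, b: u128) -> u128 {
    (a * b) % P // inputs are < 2^64, so the product fits in u128
}

// Mirrors Rpx256::exp7: x^7 = x^3 * x^4, via x^2 and x^4.
fn exp7_chain(x: u128) -> u128 {
    let x2 = mul(x, x);
    let x4 = mul(x2, x2);
    let x3 = mul(x2, x);
    mul(x3, x4)
}

// Generic square-and-multiply exponentiation, for cross-checking.
fn mod_pow(mut base: u128, mut exp: u64) -> u128 {
    let mut acc = 1u128;
    while exp > 0 {
        if exp & 1 == 1 {
            acc = mul(acc, base);
        }
        base = mul(base, base);
        exp >>= 1;
    }
    acc
}

fn main() {
    for x in [0u128, 1, 2, 123_456_789, P - 1] {
        assert_eq!(exp7_chain(x), mod_pow(x, 7));
    }
}
```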


@@ -0,0 +1,186 @@
use alloc::{collections::BTreeSet, vec::Vec};
use proptest::prelude::*;
use rand_utils::rand_value;
use super::{Felt, Hasher, Rpx256, StarkField, ZERO};
use crate::{hash::rescue::RpxDigest, ONE};
#[test]
fn hash_elements_vs_merge() {
let elements = [Felt::new(rand_value()); 8];
let digests: [RpxDigest; 2] = [
RpxDigest::new(elements[..4].try_into().unwrap()),
RpxDigest::new(elements[4..].try_into().unwrap()),
];
let m_result = Rpx256::merge(&digests);
let h_result = Rpx256::hash_elements(&elements);
assert_eq!(m_result, h_result);
}
#[test]
fn merge_vs_merge_in_domain() {
let elements = [Felt::new(rand_value()); 8];
let digests: [RpxDigest; 2] = [
RpxDigest::new(elements[..4].try_into().unwrap()),
RpxDigest::new(elements[4..].try_into().unwrap()),
];
let merge_result = Rpx256::merge(&digests);
// ----- merge with domain = 0 ----------------------------------------------------------------
// set domain to ZERO. This should not change the result.
let domain = ZERO;
let merge_in_domain_result = Rpx256::merge_in_domain(&digests, domain);
assert_eq!(merge_result, merge_in_domain_result);
// ----- merge with domain = 1 ----------------------------------------------------------------
// set domain to ONE. This should change the result.
let domain = ONE;
let merge_in_domain_result = Rpx256::merge_in_domain(&digests, domain);
assert_ne!(merge_result, merge_in_domain_result);
}
#[test]
fn hash_elements_vs_merge_with_int() {
let tmp = [Felt::new(rand_value()); 4];
let seed = RpxDigest::new(tmp);
// ----- value fits into a field element ------------------------------------------------------
let val: Felt = Felt::new(rand_value());
let m_result = Rpx256::merge_with_int(seed, val.as_int());
let mut elements = seed.as_elements().to_vec();
elements.push(val);
let h_result = Rpx256::hash_elements(&elements);
assert_eq!(m_result, h_result);
// ----- value does not fit into a field element ----------------------------------------------
let val = Felt::MODULUS + 2;
let m_result = Rpx256::merge_with_int(seed, val);
let mut elements = seed.as_elements().to_vec();
elements.push(Felt::new(val));
elements.push(ONE);
let h_result = Rpx256::hash_elements(&elements);
assert_eq!(m_result, h_result);
}
#[test]
fn hash_padding() {
// appending a zero byte to a byte string should result in a different hash
let r1 = Rpx256::hash(&[1_u8, 2, 3]);
let r2 = Rpx256::hash(&[1_u8, 2, 3, 0]);
assert_ne!(r1, r2);
// same as above but with bigger inputs
let r1 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6]);
let r2 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6, 0]);
assert_ne!(r1, r2);
// same as above but with input splitting over two elements
let r1 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6, 7]);
let r2 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6, 7, 0]);
assert_ne!(r1, r2);
// same as above but with multiple zeros
let r1 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6, 7, 0, 0]);
let r2 = Rpx256::hash(&[1_u8, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0]);
assert_ne!(r1, r2);
}
#[test]
fn hash_elements_padding() {
let e1 = [Felt::new(rand_value()); 2];
let e2 = [e1[0], e1[1], ZERO];
let r1 = Rpx256::hash_elements(&e1);
let r2 = Rpx256::hash_elements(&e2);
assert_ne!(r1, r2);
}
#[test]
fn hash_elements() {
let elements = [
ZERO,
ONE,
Felt::new(2),
Felt::new(3),
Felt::new(4),
Felt::new(5),
Felt::new(6),
Felt::new(7),
];
let digests: [RpxDigest; 2] = [
RpxDigest::new(elements[..4].try_into().unwrap()),
RpxDigest::new(elements[4..8].try_into().unwrap()),
];
let m_result = Rpx256::merge(&digests);
let h_result = Rpx256::hash_elements(&elements);
assert_eq!(m_result, h_result);
}
#[test]
fn hash_empty() {
let elements: Vec<Felt> = vec![];
let zero_digest = RpxDigest::default();
let h_result = Rpx256::hash_elements(&elements);
assert_eq!(zero_digest, h_result);
}
#[test]
fn hash_empty_bytes() {
let bytes: Vec<u8> = vec![];
let zero_digest = RpxDigest::default();
let h_result = Rpx256::hash(&bytes);
assert_eq!(zero_digest, h_result);
}
#[test]
fn sponge_bytes_with_remainder_length_wont_panic() {
// this test asserts that hashing an input whose length is not divisible by the binary chunk
// size will not panic. 113 is a non-negligible input length that is prime, and hence
// guaranteed not to be divisible by any choice of chunk size.
//
// this is a preliminary check ahead of the proptest-based fuzzing below.
Rpx256::hash(&[0; 113]);
}
#[test]
fn sponge_collision_for_wrapped_field_element() {
let a = Rpx256::hash(&[0; 8]);
let b = Rpx256::hash(&Felt::MODULUS.to_le_bytes());
assert_ne!(a, b);
}
#[test]
fn sponge_zeroes_collision() {
let mut zeroes = Vec::with_capacity(255);
let mut set = BTreeSet::new();
(0..255).for_each(|_| {
let hash = Rpx256::hash(&zeroes);
zeroes.push(0);
// panic if a collision was found
assert!(set.insert(hash));
});
}
proptest! {
#[test]
fn rpx256_wont_panic_with_arbitrary_input(ref bytes in any::<Vec<u8>>()) {
Rpx256::hash(bytes);
}
}
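The behavior exercised by `hash_padding` above comes from the byte-absorption phase of `hash()`: the input is split into 7-byte chunks and the final chunk is padded with a single 1 marker byte followed by zeros, so inputs differing only by trailing zeros encode to different element sequences. A standalone sketch of just that encoding step (element values as raw `u64`; illustrative, not the crate's API):

```rust
const BINARY_CHUNK_SIZE: usize = 7;

// Encode bytes into field-element values: 7 bytes per element, with the final
// chunk padded by a 1 marker byte followed by zeros.
fn encode(bytes: &[u8]) -> Vec<u64> {
    let chunks: Vec<&[u8]> = bytes.chunks(BINARY_CHUNK_SIZE).collect();
    let n = chunks.len();
    let mut out = Vec::with_capacity(n);
    for (i, chunk) in chunks.iter().enumerate() {
        let mut buf = [0u8; 8];
        buf[..chunk.len()].copy_from_slice(chunk);
        if i == n - 1 {
            buf[chunk.len()] = 1; // padding marker on the last chunk
        }
        out.push(u64::from_le_bytes(buf));
    }
    out
}

fn main() {
    // a trailing zero byte changes the encoding, so the hash differs
    assert_ne!(encode(&[1, 2, 3]), encode(&[1, 2, 3, 0]));
    // even when the shorter input already fills a full 7-byte chunk
    assert_ne!(encode(&[1, 2, 3, 4, 5, 6, 7]), encode(&[1, 2, 3, 4, 5, 6, 7, 0]));
}
```

In the real `hash()` function, the capacity flag `RATE_WIDTH + (num_field_elem % RATE_WIDTH)` provides additional separation between input lengths; this sketch shows only the per-chunk marker.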

src/hash/rescue/tests.rs

@@ -0,0 +1,10 @@
use rand_utils::rand_value;
use super::{Felt, FieldElement, ALPHA, INV_ALPHA};
#[test]
fn test_alphas() {
let e: Felt = Felt::new(rand_value());
let e_exp = e.exp(ALPHA);
assert_eq!(e, e_exp.exp(INV_ALPHA));
}
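`test_alphas` passes because ALPHA and INV_ALPHA are multiplicative inverses modulo p - 1, the order of the field's multiplicative group, so x^(ALPHA·INV_ALPHA) = x for every nonzero x by Fermat's little theorem (and trivially for zero). This property can be checked directly with integer arithmetic:

```rust
const P: u128 = 0xffff_ffff_0000_0001; // 2^64 - 2^32 + 1
const ALPHA: u128 = 7;
const INV_ALPHA: u128 = 10_540_996_611_094_048_183;

fn main() {
    // ALPHA * INV_ALPHA ≡ 1 (mod p - 1), so the S-Box and inverse S-Box
    // exponentiations cancel for every field element.
    assert_eq!((ALPHA * INV_ALPHA) % (P - 1), 1);
}
```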


@@ -1,299 +0,0 @@
use super::{Digest, Felt, StarkField, DIGEST_SIZE, ZERO};
use crate::utils::{
bytes_to_hex_string, hex_to_bytes, string::String, ByteReader, ByteWriter, Deserializable,
DeserializationError, HexParseError, Serializable,
};
use core::{cmp::Ordering, fmt::Display, ops::Deref};
use winter_utils::Randomizable;
/// The number of bytes needed to encode a digest
pub const DIGEST_BYTES: usize = 32;
// DIGEST TRAIT IMPLEMENTATIONS
// ================================================================================================
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))]
pub struct RpoDigest([Felt; DIGEST_SIZE]);
impl RpoDigest {
pub const fn new(value: [Felt; DIGEST_SIZE]) -> Self {
Self(value)
}
pub fn as_elements(&self) -> &[Felt] {
self.as_ref()
}
pub fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
<Self as Digest>::as_bytes(self)
}
pub fn digests_as_elements<'a, I>(digests: I) -> impl Iterator<Item = &'a Felt>
where
I: Iterator<Item = &'a Self>,
{
digests.flat_map(|d| d.0.iter())
}
}
impl Digest for RpoDigest {
fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
let mut result = [0; DIGEST_BYTES];
result[..8].copy_from_slice(&self.0[0].as_int().to_le_bytes());
result[8..16].copy_from_slice(&self.0[1].as_int().to_le_bytes());
result[16..24].copy_from_slice(&self.0[2].as_int().to_le_bytes());
result[24..].copy_from_slice(&self.0[3].as_int().to_le_bytes());
result
}
}
impl Deref for RpoDigest {
type Target = [Felt; DIGEST_SIZE];
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl Ord for RpoDigest {
fn cmp(&self, other: &Self) -> Ordering {
// compare the inner u64 of both elements.
//
// it will iterate the elements and will return the first computation different than
// `Equal`. Otherwise, the ordering is equal.
//
// the endianness is irrelevant here because, this being a cryptographically secure hash,
// the digest shouldn't preserve any ordering property of its input.
//
// finally, we use `Felt::inner` instead of `Felt::as_int` so we avoid performing a
// montgomery reduction for every limb. that is safe because every inner element of the
// digest is guaranteed to be in its canonical form (that is, `x in [0,p)`).
self.0.iter().map(Felt::inner).zip(other.0.iter().map(Felt::inner)).fold(
Ordering::Equal,
|ord, (a, b)| match ord {
Ordering::Equal => a.cmp(&b),
_ => ord,
},
)
}
}
impl PartialOrd for RpoDigest {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Display for RpoDigest {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
let encoded: String = self.into();
write!(f, "{}", encoded)?;
Ok(())
}
}
impl Randomizable for RpoDigest {
const VALUE_SIZE: usize = DIGEST_BYTES;
fn from_random_bytes(bytes: &[u8]) -> Option<Self> {
let bytes_array: Option<[u8; 32]> = bytes.try_into().ok();
if let Some(bytes_array) = bytes_array {
Self::try_from(bytes_array).ok()
} else {
None
}
}
}
// CONVERSIONS: FROM RPO DIGEST
// ================================================================================================
impl From<&RpoDigest> for [Felt; DIGEST_SIZE] {
fn from(value: &RpoDigest) -> Self {
value.0
}
}
impl From<RpoDigest> for [Felt; DIGEST_SIZE] {
fn from(value: RpoDigest) -> Self {
value.0
}
}
impl From<&RpoDigest> for [u64; DIGEST_SIZE] {
fn from(value: &RpoDigest) -> Self {
[
value.0[0].as_int(),
value.0[1].as_int(),
value.0[2].as_int(),
value.0[3].as_int(),
]
}
}
impl From<RpoDigest> for [u64; DIGEST_SIZE] {
fn from(value: RpoDigest) -> Self {
[
value.0[0].as_int(),
value.0[1].as_int(),
value.0[2].as_int(),
value.0[3].as_int(),
]
}
}
impl From<&RpoDigest> for [u8; DIGEST_BYTES] {
fn from(value: &RpoDigest) -> Self {
value.as_bytes()
}
}
impl From<RpoDigest> for [u8; DIGEST_BYTES] {
fn from(value: RpoDigest) -> Self {
value.as_bytes()
}
}
impl From<RpoDigest> for String {
/// The returned string starts with `0x`.
fn from(value: RpoDigest) -> Self {
bytes_to_hex_string(value.as_bytes())
}
}
impl From<&RpoDigest> for String {
/// The returned string starts with `0x`.
fn from(value: &RpoDigest) -> Self {
(*value).into()
}
}
// CONVERSIONS: TO DIGEST
// ================================================================================================
impl From<[Felt; DIGEST_SIZE]> for RpoDigest {
fn from(value: [Felt; DIGEST_SIZE]) -> Self {
Self(value)
}
}
impl TryFrom<[u8; DIGEST_BYTES]> for RpoDigest {
type Error = HexParseError;
fn try_from(value: [u8; DIGEST_BYTES]) -> Result<Self, Self::Error> {
// Note: since the input length is known, the slice-to-array conversions must succeed, so
// the `unwrap`s below are safe
let a = u64::from_le_bytes(value[0..8].try_into().unwrap());
let b = u64::from_le_bytes(value[8..16].try_into().unwrap());
let c = u64::from_le_bytes(value[16..24].try_into().unwrap());
let d = u64::from_le_bytes(value[24..32].try_into().unwrap());
if [a, b, c, d].iter().any(|v| *v >= Felt::MODULUS) {
return Err(HexParseError::OutOfRange);
}
Ok(RpoDigest([Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)]))
}
}
impl TryFrom<&str> for RpoDigest {
type Error = HexParseError;
/// Expects the string to start with `0x`.
fn try_from(value: &str) -> Result<Self, Self::Error> {
hex_to_bytes(value).and_then(|v| v.try_into())
}
}
impl TryFrom<String> for RpoDigest {
type Error = HexParseError;
/// Expects the string to start with `0x`.
fn try_from(value: String) -> Result<Self, Self::Error> {
value.as_str().try_into()
}
}
impl TryFrom<&String> for RpoDigest {
type Error = HexParseError;
/// Expects the string to start with `0x`.
fn try_from(value: &String) -> Result<Self, Self::Error> {
value.as_str().try_into()
}
}
// SERIALIZATION / DESERIALIZATION
// ================================================================================================
impl Serializable for RpoDigest {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write_bytes(&self.as_bytes());
}
}
impl Deserializable for RpoDigest {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let mut inner: [Felt; DIGEST_SIZE] = [ZERO; DIGEST_SIZE];
for inner in inner.iter_mut() {
let e = source.read_u64()?;
if e >= Felt::MODULUS {
return Err(DeserializationError::InvalidValue(String::from(
"Value not in the appropriate range",
)));
}
*inner = Felt::new(e);
}
Ok(Self(inner))
}
}
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
use super::{Deserializable, Felt, RpoDigest, Serializable, DIGEST_BYTES};
use crate::utils::SliceReader;
use rand_utils::rand_value;
#[test]
fn digest_serialization() {
let e1 = Felt::new(rand_value());
let e2 = Felt::new(rand_value());
let e3 = Felt::new(rand_value());
let e4 = Felt::new(rand_value());
let d1 = RpoDigest([e1, e2, e3, e4]);
let mut bytes = vec![];
d1.write_into(&mut bytes);
assert_eq!(DIGEST_BYTES, bytes.len());
let mut reader = SliceReader::new(&bytes);
let d2 = RpoDigest::read_from(&mut reader).unwrap();
assert_eq!(d1, d2);
}
#[cfg(feature = "std")]
#[test]
fn digest_encoding() {
let digest = RpoDigest([
Felt::new(rand_value()),
Felt::new(rand_value()),
Felt::new(rand_value()),
Felt::new(rand_value()),
]);
let string: String = digest.into();
let round_trip: RpoDigest = string.try_into().expect("decoding failed");
assert_eq!(digest, round_trip);
}
}
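The `Ord` implementation above compares the four limbs with a fold that keeps the first non-`Equal` result, which is exactly lexicographic order on the limb arrays. A standalone sketch of the same pattern over raw `u64` limbs (names illustrative, not the crate's API):

```rust
use core::cmp::Ordering;

// Mirrors RpoDigest::cmp: fold over paired limbs, keeping the first
// non-Equal comparison result.
fn digest_cmp(a: &[u64; 4], b: &[u64; 4]) -> Ordering {
    a.iter().zip(b.iter()).fold(Ordering::Equal, |ord, (x, y)| match ord {
        Ordering::Equal => x.cmp(y),
        _ => ord,
    })
}

fn main() {
    assert_eq!(digest_cmp(&[1, 2, 3, 4], &[1, 2, 3, 4]), Ordering::Equal);
    // the first differing limb decides, regardless of later limbs
    assert_eq!(digest_cmp(&[1, 2, 3, 4], &[1, 2, 4, 0]), Ordering::Less);
    // identical to the derived lexicographic order on the arrays
    assert_eq!(digest_cmp(&[9, 0, 0, 0], &[1, 9, 9, 9]), [9u64, 0, 0, 0].cmp(&[1, 9, 9, 9]));
}
```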


@@ -1,905 +0,0 @@
use super::{Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ONE, ZERO};
use core::{convert::TryInto, ops::Range};
mod digest;
pub use digest::RpoDigest;
mod mds_freq;
use mds_freq::mds_multiply_freq;
#[cfg(test)]
mod tests;
#[cfg(all(target_feature = "sve", feature = "sve"))]
#[link(name = "rpo_sve", kind = "static")]
extern "C" {
fn add_constants_and_apply_sbox(
state: *mut std::ffi::c_ulong,
constants: *const std::ffi::c_ulong,
) -> bool;
fn add_constants_and_apply_inv_sbox(
state: *mut std::ffi::c_ulong,
constants: *const std::ffi::c_ulong,
) -> bool;
}
// CONSTANTS
// ================================================================================================
/// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate and
/// the remaining 4 elements are reserved for capacity.
const STATE_WIDTH: usize = 12;
/// The rate portion of the state is located in elements 4 through 11.
const RATE_RANGE: Range<usize> = 4..12;
const RATE_WIDTH: usize = RATE_RANGE.end - RATE_RANGE.start;
const INPUT1_RANGE: Range<usize> = 4..8;
const INPUT2_RANGE: Range<usize> = 8..12;
/// The capacity portion of the state is located in elements 0, 1, 2, and 3.
const CAPACITY_RANGE: Range<usize> = 0..4;
/// The output of the hash function is a digest which consists of 4 field elements or 32 bytes.
///
/// The digest is returned from state elements 4, 5, 6, and 7 (the first four elements of the
/// rate portion).
const DIGEST_RANGE: Range<usize> = 4..8;
const DIGEST_SIZE: usize = DIGEST_RANGE.end - DIGEST_RANGE.start;
/// The number of rounds is set to 7 to target a 128-bit security level
const NUM_ROUNDS: usize = 7;
/// The number of bytes per chunk used to encode a field element when hashing a sequence of bytes
const BINARY_CHUNK_SIZE: usize = 7;
/// S-Box and Inverse S-Box powers.
///
/// The constants are defined for tests only because the exponentiations in the code are unrolled
/// for efficiency reasons.
#[cfg(test)]
const ALPHA: u64 = 7;
#[cfg(test)]
const INV_ALPHA: u64 = 10540996611094048183;
// HASHER IMPLEMENTATION
// ================================================================================================
/// Implementation of the Rescue Prime Optimized hash function with 256-bit output.
///
/// The hash function is implemented according to the Rescue Prime Optimized
/// [specifications](https://eprint.iacr.org/2022/1577)
///
/// The parameters used to instantiate the function are:
/// * Field: 64-bit prime field with modulus 2^64 - 2^32 + 1.
/// * State width: 12 field elements.
/// * Capacity size: 4 field elements.
/// * Number of rounds: 7.
/// * S-Box degree: 7.
///
/// The above parameters target 128-bit security level. The digest consists of four field elements
/// and it can be serialized into 32 bytes (256 bits).
///
/// ## Hash output consistency
/// Functions [hash_elements()](Rpo256::hash_elements), [merge()](Rpo256::merge), and
/// [merge_with_int()](Rpo256::merge_with_int) are internally consistent. That is, computing
/// a hash for the same set of elements using these functions will always produce the same
/// result. For example, merging two digests using [merge()](Rpo256::merge) will produce the
/// same result as hashing 8 elements which make up these digests using
/// [hash_elements()](Rpo256::hash_elements) function.
///
/// However, the [hash()](Rpo256::hash) function is not consistent with the functions mentioned
/// above. For example, if we take two field elements, serialize them to bytes, and hash them
/// using [hash()](Rpo256::hash), the result will differ from the result obtained by hashing
/// these elements directly using the [hash_elements()](Rpo256::hash_elements) function. The
/// reason for this difference is that the [hash()](Rpo256::hash) function needs to be able to
/// handle arbitrary binary strings, which may or may not encode valid field elements - and
/// thus, the deserialization procedure used by this function is different from the procedure
/// used to deserialize valid field elements.
///
/// Thus, if the underlying data consists of valid field elements, it might make more sense
/// to deserialize them into field elements and then hash them using the
/// [hash_elements()](Rpo256::hash_elements) function rather than hashing the serialized bytes
/// using the [hash()](Rpo256::hash) function.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Rpo256();
impl Hasher for Rpo256 {
/// #### Collision resistance
///
/// Rpo256 collision resistance is the same as the security level, that is, 128 bits.
/// However, our setup of the capacity registers might drop it to 126.
///
/// Related issue: [#69](https://github.com/0xPolygonMiden/crypto/issues/69)
const COLLISION_RESISTANCE: u32 = 128;
type Digest = RpoDigest;
fn hash(bytes: &[u8]) -> Self::Digest {
// initialize the state with zeroes
let mut state = [ZERO; STATE_WIDTH];
// set the capacity (first element) to a flag indicating whether or not the input length is
// evenly divisible by the rate. this prevents collisions between padded and non-padded
// inputs, and removes the need for an extra permutation when the input is evenly divisible.
let is_rate_multiple = bytes.len() % RATE_WIDTH == 0;
if !is_rate_multiple {
state[CAPACITY_RANGE.start] = ONE;
}
// initialize a buffer to receive the little-endian elements.
let mut buf = [0_u8; 8];
// iterate the chunks of bytes, creating a field element from each chunk and copying it
// into the state.
//
// every time the rate range is filled, a permutation is performed. if the final value of
// `i` is not zero, then the chunks weren't enough to fill the rate range completely, and an
// additional permutation must be performed.
let i = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |i, chunk| {
// the last element of the iteration may or may not be a full chunk. if it's not, then
// we need to pad the remainder bytes of the chunk with zeroes, separated by a `1`.
// this will avoid collisions.
if chunk.len() == BINARY_CHUNK_SIZE {
buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk);
} else {
buf.fill(0);
buf[..chunk.len()].copy_from_slice(chunk);
buf[chunk.len()] = 1;
}
// set the current rate element to the input. since we take at most 7 bytes, we are
// guaranteed that the input data will fit into a single field element.
state[RATE_RANGE.start + i] = Felt::new(u64::from_le_bytes(buf));
// proceed filling the range. if it's full, then we apply a permutation and reset the
// counter to the beginning of the range.
if i == RATE_WIDTH - 1 {
Self::apply_permutation(&mut state);
0
} else {
i + 1
}
});
// if we absorbed some elements but didn't apply a permutation to them (would happen when
// the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation. we
// don't need to apply any extra padding because the first capacity element contains a
// flag indicating whether the input is evenly divisible by the rate.
if i != 0 {
state[RATE_RANGE.start + i..RATE_RANGE.end].fill(ZERO);
state[RATE_RANGE.start + i] = ONE;
Self::apply_permutation(&mut state);
}
// return the first 4 elements of the rate as hash result.
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}
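The chunking logic in `hash()` above can be sketched in isolation. The helper below is hypothetical (not part of the crate); it mirrors how 7-byte chunks are packed into 64-bit limbs, with a partial final chunk padded by a single 1 byte:

```rust
// Hypothetical standalone sketch of the byte packing in `hash()`: each 7-byte
// chunk fits a 64-bit field element; a partial final chunk is terminated with
// a 1 byte so that inputs of different lengths cannot collide.
fn pack_chunks(bytes: &[u8]) -> Vec<u64> {
    const BINARY_CHUNK_SIZE: usize = 7;
    let mut out = Vec::new();
    let mut buf = [0u8; 8];
    for chunk in bytes.chunks(BINARY_CHUNK_SIZE) {
        if chunk.len() == BINARY_CHUNK_SIZE {
            buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk);
        } else {
            // partial chunk: copy the bytes, then mark the boundary with a 1
            buf.fill(0);
            buf[..chunk.len()].copy_from_slice(chunk);
            buf[chunk.len()] = 1;
        }
        out.push(u64::from_le_bytes(buf));
    }
    out
}
```

For example, the single byte `[2]` becomes the limb `2 + 1·256 = 258`, the `1` being the padding marker that distinguishes it from a genuine two-byte input.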
fn merge(values: &[Self::Digest; 2]) -> Self::Digest {
// initialize the state by copying the digest elements into the rate portion of the state
// (8 total elements), and set the capacity elements to 0.
let mut state = [ZERO; STATE_WIDTH];
let it = Self::Digest::digests_as_elements(values.iter());
for (i, v) in it.enumerate() {
state[RATE_RANGE.start + i] = *v;
}
// apply the RPO permutation and return the first four elements of the state
Self::apply_permutation(&mut state);
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}
fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest {
// initialize the state as follows:
// - seed is copied into the first 4 elements of the rate portion of the state.
// - if the value fits into a single field element, copy it into the fifth rate element
// and set the sixth rate element to 1.
// - if the value doesn't fit into a single field element, split it into two field
// elements, copy them into rate elements 5 and 6, and set the seventh rate element
// to 1.
// - set the first capacity element to 1
let mut state = [ZERO; STATE_WIDTH];
state[INPUT1_RANGE].copy_from_slice(seed.as_elements());
state[INPUT2_RANGE.start] = Felt::new(value);
if value < Felt::MODULUS {
state[INPUT2_RANGE.start + 1] = ONE;
} else {
state[INPUT2_RANGE.start + 1] = Felt::new(value / Felt::MODULUS);
state[INPUT2_RANGE.start + 2] = ONE;
}
// common padding for both cases
state[CAPACITY_RANGE.start] = ONE;
// apply the RPO permutation and return the first four elements of the state
Self::apply_permutation(&mut state);
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}
}
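The limb encoding that `merge_with_int` applies to the `u64` value can be sketched as follows (hypothetical helper, assuming `Felt::new` reduces its argument mod p):

```rust
// Hypothetical sketch of the integer encoding in `merge_with_int`. A value that
// does not fit into one field element is split into two base-p limbs; the
// trailing 1 marks where the encoding ends, so the two cases cannot collide.
fn encode_int(value: u64) -> Vec<u64> {
    const MODULUS: u64 = 0xFFFF_FFFF_0000_0001; // p = 2^64 - 2^32 + 1
    if value < MODULUS {
        vec![value, 1]
    } else {
        // for a u64 input the high limb can only be 1, since u64::MAX / p == 1
        vec![value % MODULUS, value / MODULUS, 1]
    }
}
```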
impl ElementHasher for Rpo256 {
type BaseField = Felt;
fn hash_elements<E: FieldElement<BaseField = Self::BaseField>>(elements: &[E]) -> Self::Digest {
// convert the elements into a list of base field elements
let elements = E::slice_as_base_elements(elements);
// initialize state to all zeros, except for the first element of the capacity part, which
// is set to 1 if the number of elements is not a multiple of RATE_WIDTH.
let mut state = [ZERO; STATE_WIDTH];
if elements.len() % RATE_WIDTH != 0 {
state[CAPACITY_RANGE.start] = ONE;
}
// absorb elements into the state one by one until the rate portion of the state is filled
// up; then apply the Rescue permutation and start absorbing again; repeat until all
// elements have been absorbed
let mut i = 0;
for &element in elements.iter() {
state[RATE_RANGE.start + i] = element;
i += 1;
if i % RATE_WIDTH == 0 {
Self::apply_permutation(&mut state);
i = 0;
}
}
// if we absorbed some elements but didn't apply a permutation to them (would happen when
// the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation after
// padding by appending a 1 followed by as many 0 as necessary to make the input length a
// multiple of the RATE_WIDTH.
if i > 0 {
state[RATE_RANGE.start + i] = ONE;
i += 1;
while i != RATE_WIDTH {
state[RATE_RANGE.start + i] = ZERO;
i += 1;
}
Self::apply_permutation(&mut state);
}
// return the first 4 elements of the state as hash result
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}
}
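The absorb-and-pad loop in `hash_elements` can be modeled generically; the helper below is a hypothetical sketch over raw `u64`s with a pluggable permutation, not the crate's API:

```rust
// Hypothetical model of the sponge in `hash_elements`: elements fill the rate
// (positions 4..12); each full rate block triggers a permutation; a partial
// final block is padded with a single 1 followed by zeros, and the capacity
// flag (position 0) records whether padding was needed.
fn absorb_and_pad(elements: &[u64], permute: impl Fn(&mut [u64; 12])) -> [u64; 12] {
    let mut state = [0u64; 12];
    if elements.len() % 8 != 0 {
        state[0] = 1; // capacity domain-separation flag
    }
    let mut i = 0;
    for &e in elements {
        state[4 + i] = e;
        i += 1;
        if i == 8 {
            permute(&mut state);
            i = 0;
        }
    }
    if i > 0 {
        state[4 + i] = 1; // padding: a 1 ...
        for j in i + 1..8 {
            state[4 + j] = 0; // ... followed by zeros over any stale rate values
        }
        permute(&mut state);
    }
    state
}
```

Passing an identity closure as the permutation makes the padding visible: absorbing `[5, 6, 7]` leaves the rate as `5, 6, 7, 1, 0, 0, 0, 0` with the capacity flag set.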
// HASH FUNCTION IMPLEMENTATION
// ================================================================================================
impl Rpo256 {
// CONSTANTS
// --------------------------------------------------------------------------------------------
/// The number of rounds is set to 7 to target 128-bit security level.
pub const NUM_ROUNDS: usize = NUM_ROUNDS;
/// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate and
/// the remaining 4 elements are reserved for capacity.
pub const STATE_WIDTH: usize = STATE_WIDTH;
/// The rate portion of the state is located in elements 4 through 11 (inclusive).
pub const RATE_RANGE: Range<usize> = RATE_RANGE;
/// The capacity portion of the state is located in elements 0, 1, 2, and 3.
pub const CAPACITY_RANGE: Range<usize> = CAPACITY_RANGE;
/// The output of the hash function can be read from state elements 4, 5, 6, and 7.
pub const DIGEST_RANGE: Range<usize> = DIGEST_RANGE;
/// MDS matrix used for computing the linear layer in an RPO round.
pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = MDS;
/// Round constants added to the hasher state in the first half of the RPO round.
pub const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK1;
/// Round constants added to the hasher state in the second half of the RPO round.
pub const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK2;
// TRAIT PASS-THROUGH FUNCTIONS
// --------------------------------------------------------------------------------------------
/// Returns a hash of the provided sequence of bytes.
#[inline(always)]
pub fn hash(bytes: &[u8]) -> RpoDigest {
<Self as Hasher>::hash(bytes)
}
/// Returns a hash of two digests. This method is intended for use in construction of
/// Merkle trees and verification of Merkle paths.
#[inline(always)]
pub fn merge(values: &[RpoDigest; 2]) -> RpoDigest {
<Self as Hasher>::merge(values)
}
/// Returns a hash of the provided field elements.
#[inline(always)]
pub fn hash_elements<E: FieldElement<BaseField = Felt>>(elements: &[E]) -> RpoDigest {
<Self as ElementHasher>::hash_elements(elements)
}
// DOMAIN IDENTIFIER
// --------------------------------------------------------------------------------------------
/// Returns a hash of two digests and a domain identifier.
pub fn merge_in_domain(values: &[RpoDigest; 2], domain: Felt) -> RpoDigest {
// initialize the state by copying the digest elements into the rate portion of the state
// (8 total elements), and set the capacity elements to 0.
let mut state = [ZERO; STATE_WIDTH];
let it = RpoDigest::digests_as_elements(values.iter());
for (i, v) in it.enumerate() {
state[RATE_RANGE.start + i] = *v;
}
// set the second capacity element to the domain value. The first capacity element is used
// for padding purposes.
state[CAPACITY_RANGE.start + 1] = domain;
// apply the RPO permutation and return the first four elements of the state
Self::apply_permutation(&mut state);
RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap())
}
// RESCUE PERMUTATION
// --------------------------------------------------------------------------------------------
/// Applies RPO permutation to the provided state.
#[inline(always)]
pub fn apply_permutation(state: &mut [Felt; STATE_WIDTH]) {
for i in 0..NUM_ROUNDS {
Self::apply_round(state, i);
}
}
/// RPO round function.
#[inline(always)]
pub fn apply_round(state: &mut [Felt; STATE_WIDTH], round: usize) {
// apply first half of RPO round
Self::apply_mds(state);
if !Self::optimized_add_constants_and_apply_sbox(state, &ARK1[round]) {
Self::add_constants(state, &ARK1[round]);
Self::apply_sbox(state);
}
// apply second half of RPO round
Self::apply_mds(state);
if !Self::optimized_add_constants_and_apply_inv_sbox(state, &ARK2[round]) {
Self::add_constants(state, &ARK2[round]);
Self::apply_inv_sbox(state);
}
}
// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------
#[inline(always)]
#[cfg(all(target_feature = "sve", feature = "sve"))]
fn optimized_add_constants_and_apply_sbox(
state: &mut [Felt; STATE_WIDTH],
ark: &[Felt; STATE_WIDTH],
) -> bool {
unsafe {
add_constants_and_apply_sbox(state.as_mut_ptr() as *mut u64, ark.as_ptr() as *const u64)
}
}
#[inline(always)]
#[cfg(not(all(target_feature = "sve", feature = "sve")))]
fn optimized_add_constants_and_apply_sbox(
_state: &mut [Felt; STATE_WIDTH],
_ark: &[Felt; STATE_WIDTH],
) -> bool {
false
}
#[inline(always)]
#[cfg(all(target_feature = "sve", feature = "sve"))]
fn optimized_add_constants_and_apply_inv_sbox(
state: &mut [Felt; STATE_WIDTH],
ark: &[Felt; STATE_WIDTH],
) -> bool {
unsafe {
add_constants_and_apply_inv_sbox(
state.as_mut_ptr() as *mut u64,
ark.as_ptr() as *const u64,
)
}
}
#[inline(always)]
#[cfg(not(all(target_feature = "sve", feature = "sve")))]
fn optimized_add_constants_and_apply_inv_sbox(
_state: &mut [Felt; STATE_WIDTH],
_ark: &[Felt; STATE_WIDTH],
) -> bool {
false
}
#[inline(always)]
fn apply_mds(state: &mut [Felt; STATE_WIDTH]) {
let mut result = [ZERO; STATE_WIDTH];
// Using the linearity of the operations we can split the state into a low||high decomposition
// and operate on each with no overflow and then combine/reduce the result to a field element.
// No overflow is guaranteed by the fact that the MDS matrix entries are small powers of two
// in the frequency domain.
let mut state_l = [0u64; STATE_WIDTH];
let mut state_h = [0u64; STATE_WIDTH];
for r in 0..STATE_WIDTH {
let s = state[r].inner();
state_h[r] = s >> 32;
state_l[r] = (s as u32) as u64;
}
let state_h = mds_multiply_freq(state_h);
let state_l = mds_multiply_freq(state_l);
for r in 0..STATE_WIDTH {
let s = state_l[r] as u128 + ((state_h[r] as u128) << 32);
let s_hi = (s >> 64) as u64;
let s_lo = s as u64;
let z = (s_hi << 32) - s_hi;
let (res, over) = s_lo.overflowing_add(z);
result[r] = Felt::from_mont(res.wrapping_add(0u32.wrapping_sub(over as u32) as u64));
}
*state = result;
}
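The final combine/reduce step of `apply_mds` can be sketched as a standalone reduction over plain integers (hypothetical helper; the crate operates on Montgomery representations via `Felt::from_mont`, and the MDS bounds guarantee the high part fits in 32 bits):

```rust
// Hypothetical sketch of the 128 -> 64 bit reduction in `apply_mds`, using
// 2^64 ≡ 2^32 - 1 (mod p) for the Goldilocks prime p = 2^64 - 2^32 + 1.
fn reduce128(s: u128) -> u64 {
    const P: u64 = 0xFFFF_FFFF_0000_0001;
    let s_lo = s as u64;
    let s_hi = (s >> 64) as u64;
    debug_assert!(s_hi < 1 << 32, "caller must keep the high part below 2^32");
    // s_hi * 2^64 ≡ s_hi * (2^32 - 1) (mod p)
    let z = (s_hi << 32) - s_hi;
    let (res, over) = s_lo.overflowing_add(z);
    // an overflow past 2^64 contributes another 2^32 - 1
    let res = res.wrapping_add(0u32.wrapping_sub(over as u32) as u64);
    // canonicalize: the result is congruent mod p but may still be >= p
    if res >= P { res - P } else { res }
}
```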
#[inline(always)]
fn add_constants(state: &mut [Felt; STATE_WIDTH], ark: &[Felt; STATE_WIDTH]) {
state.iter_mut().zip(ark).for_each(|(s, &k)| *s += k);
}
#[inline(always)]
fn apply_sbox(state: &mut [Felt; STATE_WIDTH]) {
state[0] = state[0].exp7();
state[1] = state[1].exp7();
state[2] = state[2].exp7();
state[3] = state[3].exp7();
state[4] = state[4].exp7();
state[5] = state[5].exp7();
state[6] = state[6].exp7();
state[7] = state[7].exp7();
state[8] = state[8].exp7();
state[9] = state[9].exp7();
state[10] = state[10].exp7();
state[11] = state[11].exp7();
}
#[inline(always)]
fn apply_inv_sbox(state: &mut [Felt; STATE_WIDTH]) {
// compute base^10540996611094048183 using 72 multiplications per array element
// 10540996611094048183 = b1001001001001001001001001001000110110110110110110110110110110111
// compute base^10
let mut t1 = *state;
t1.iter_mut().for_each(|t| *t = t.square());
// compute base^100
let mut t2 = t1;
t2.iter_mut().for_each(|t| *t = t.square());
// compute base^100100
let t3 = Self::exp_acc::<Felt, STATE_WIDTH, 3>(t2, t2);
// compute base^100100100100
let t4 = Self::exp_acc::<Felt, STATE_WIDTH, 6>(t3, t3);
// compute base^100100100100100100100100
let t5 = Self::exp_acc::<Felt, STATE_WIDTH, 12>(t4, t4);
// compute base^100100100100100100100100100100
let t6 = Self::exp_acc::<Felt, STATE_WIDTH, 6>(t5, t3);
// compute base^1001001001001001001001001001000100100100100100100100100100100
let t7 = Self::exp_acc::<Felt, STATE_WIDTH, 31>(t6, t6);
// compute base^1001001001001001001001001001000110110110110110110110110110110111
for (i, s) in state.iter_mut().enumerate() {
let a = (t7[i].square() * t6[i]).square().square();
let b = t1[i] * t2[i] * *s;
*s = a * b;
}
}
#[inline(always)]
fn exp_acc<B: StarkField, const N: usize, const M: usize>(
base: [B; N],
tail: [B; N],
) -> [B; N] {
let mut result = base;
for _ in 0..M {
result.iter_mut().for_each(|r| *r = r.square());
}
result.iter_mut().zip(tail).for_each(|(r, t)| *r *= t);
result
}
}
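The test-only `ALPHA` and `INV_ALPHA` constants are related by inversion modulo the multiplicative group order p - 1; a quick standalone check over plain integers (outside the crate):

```rust
// x -> x^INV_ALPHA inverts x -> x^ALPHA on the multiplicative group of the
// field exactly when ALPHA * INV_ALPHA ≡ 1 (mod p - 1), where
// p = 2^64 - 2^32 + 1 is the field modulus.
fn alpha_inv_check() -> u128 {
    const ALPHA: u128 = 7;
    const INV_ALPHA: u128 = 10_540_996_611_094_048_183;
    const P: u128 = (1u128 << 64) - (1 << 32) + 1;
    (ALPHA * INV_ALPHA) % (P - 1)
}
```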
// MDS
// ================================================================================================
/// RPO MDS matrix
const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = [
[
Felt::new(7),
Felt::new(23),
Felt::new(8),
Felt::new(26),
Felt::new(13),
Felt::new(10),
Felt::new(9),
Felt::new(7),
Felt::new(6),
Felt::new(22),
Felt::new(21),
Felt::new(8),
],
[
Felt::new(8),
Felt::new(7),
Felt::new(23),
Felt::new(8),
Felt::new(26),
Felt::new(13),
Felt::new(10),
Felt::new(9),
Felt::new(7),
Felt::new(6),
Felt::new(22),
Felt::new(21),
],
[
Felt::new(21),
Felt::new(8),
Felt::new(7),
Felt::new(23),
Felt::new(8),
Felt::new(26),
Felt::new(13),
Felt::new(10),
Felt::new(9),
Felt::new(7),
Felt::new(6),
Felt::new(22),
],
[
Felt::new(22),
Felt::new(21),
Felt::new(8),
Felt::new(7),
Felt::new(23),
Felt::new(8),
Felt::new(26),
Felt::new(13),
Felt::new(10),
Felt::new(9),
Felt::new(7),
Felt::new(6),
],
[
Felt::new(6),
Felt::new(22),
Felt::new(21),
Felt::new(8),
Felt::new(7),
Felt::new(23),
Felt::new(8),
Felt::new(26),
Felt::new(13),
Felt::new(10),
Felt::new(9),
Felt::new(7),
],
[
Felt::new(7),
Felt::new(6),
Felt::new(22),
Felt::new(21),
Felt::new(8),
Felt::new(7),
Felt::new(23),
Felt::new(8),
Felt::new(26),
Felt::new(13),
Felt::new(10),
Felt::new(9),
],
[
Felt::new(9),
Felt::new(7),
Felt::new(6),
Felt::new(22),
Felt::new(21),
Felt::new(8),
Felt::new(7),
Felt::new(23),
Felt::new(8),
Felt::new(26),
Felt::new(13),
Felt::new(10),
],
[
Felt::new(10),
Felt::new(9),
Felt::new(7),
Felt::new(6),
Felt::new(22),
Felt::new(21),
Felt::new(8),
Felt::new(7),
Felt::new(23),
Felt::new(8),
Felt::new(26),
Felt::new(13),
],
[
Felt::new(13),
Felt::new(10),
Felt::new(9),
Felt::new(7),
Felt::new(6),
Felt::new(22),
Felt::new(21),
Felt::new(8),
Felt::new(7),
Felt::new(23),
Felt::new(8),
Felt::new(26),
],
[
Felt::new(26),
Felt::new(13),
Felt::new(10),
Felt::new(9),
Felt::new(7),
Felt::new(6),
Felt::new(22),
Felt::new(21),
Felt::new(8),
Felt::new(7),
Felt::new(23),
Felt::new(8),
],
[
Felt::new(8),
Felt::new(26),
Felt::new(13),
Felt::new(10),
Felt::new(9),
Felt::new(7),
Felt::new(6),
Felt::new(22),
Felt::new(21),
Felt::new(8),
Felt::new(7),
Felt::new(23),
],
[
Felt::new(23),
Felt::new(8),
Felt::new(26),
Felt::new(13),
Felt::new(10),
Felt::new(9),
Felt::new(7),
Felt::new(6),
Felt::new(22),
Felt::new(21),
Felt::new(8),
Felt::new(7),
],
];
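The MDS matrix above is circulant: every row is the previous row rotated right by one, which is what enables the frequency-domain multiplication used in `apply_mds`. A hypothetical spot-check over the raw entries:

```rust
// Reconstruct row r of a right-circulant matrix from its first row; the
// listed rows above should match (spot-checking rows 1 and 11 below).
fn circulant_row(row0: &[u64; 12], r: usize) -> Vec<u64> {
    (0..12).map(|c| row0[(c + 12 - r) % 12]).collect()
}
```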
// ROUND CONSTANTS
// ================================================================================================
/// Rescue round constants,
/// computed as in the [specifications](https://github.com/ASDiscreteMathematics/rpo).
///
/// The constants are broken up into two arrays ARK1 and ARK2; ARK1 contains the constants for the
/// first half of an RPO round, and ARK2 contains the constants for the second half of an RPO round.
const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
[
Felt::new(5789762306288267392),
Felt::new(6522564764413701783),
Felt::new(17809893479458208203),
Felt::new(107145243989736508),
Felt::new(6388978042437517382),
Felt::new(15844067734406016715),
Felt::new(9975000513555218239),
Felt::new(3344984123768313364),
Felt::new(9959189626657347191),
Felt::new(12960773468763563665),
Felt::new(9602914297752488475),
Felt::new(16657542370200465908),
],
[
Felt::new(12987190162843096997),
Felt::new(653957632802705281),
Felt::new(4441654670647621225),
Felt::new(4038207883745915761),
Felt::new(5613464648874830118),
Felt::new(13222989726778338773),
Felt::new(3037761201230264149),
Felt::new(16683759727265180203),
Felt::new(8337364536491240715),
Felt::new(3227397518293416448),
Felt::new(8110510111539674682),
Felt::new(2872078294163232137),
],
[
Felt::new(18072785500942327487),
Felt::new(6200974112677013481),
Felt::new(17682092219085884187),
Felt::new(10599526828986756440),
Felt::new(975003873302957338),
Felt::new(8264241093196931281),
Felt::new(10065763900435475170),
Felt::new(2181131744534710197),
Felt::new(6317303992309418647),
Felt::new(1401440938888741532),
Felt::new(8884468225181997494),
Felt::new(13066900325715521532),
],
[
Felt::new(5674685213610121970),
Felt::new(5759084860419474071),
Felt::new(13943282657648897737),
Felt::new(1352748651966375394),
Felt::new(17110913224029905221),
Felt::new(1003883795902368422),
Felt::new(4141870621881018291),
Felt::new(8121410972417424656),
Felt::new(14300518605864919529),
Felt::new(13712227150607670181),
Felt::new(17021852944633065291),
Felt::new(6252096473787587650),
],
[
Felt::new(4887609836208846458),
Felt::new(3027115137917284492),
Felt::new(9595098600469470675),
Felt::new(10528569829048484079),
Felt::new(7864689113198939815),
Felt::new(17533723827845969040),
Felt::new(5781638039037710951),
Felt::new(17024078752430719006),
Felt::new(109659393484013511),
Felt::new(7158933660534805869),
Felt::new(2955076958026921730),
Felt::new(7433723648458773977),
],
[
Felt::new(16308865189192447297),
Felt::new(11977192855656444890),
Felt::new(12532242556065780287),
Felt::new(14594890931430968898),
Felt::new(7291784239689209784),
Felt::new(5514718540551361949),
Felt::new(10025733853830934803),
Felt::new(7293794580341021693),
Felt::new(6728552937464861756),
Felt::new(6332385040983343262),
Felt::new(13277683694236792804),
Felt::new(2600778905124452676),
],
[
Felt::new(7123075680859040534),
Felt::new(1034205548717903090),
Felt::new(7717824418247931797),
Felt::new(3019070937878604058),
Felt::new(11403792746066867460),
Felt::new(10280580802233112374),
Felt::new(337153209462421218),
Felt::new(13333398568519923717),
Felt::new(3596153696935337464),
Felt::new(8104208463525993784),
Felt::new(14345062289456085693),
Felt::new(17036731477169661256),
],
];
const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [
[
Felt::new(6077062762357204287),
Felt::new(15277620170502011191),
Felt::new(5358738125714196705),
Felt::new(14233283787297595718),
Felt::new(13792579614346651365),
Felt::new(11614812331536767105),
Felt::new(14871063686742261166),
Felt::new(10148237148793043499),
Felt::new(4457428952329675767),
Felt::new(15590786458219172475),
Felt::new(10063319113072092615),
Felt::new(14200078843431360086),
],
[
Felt::new(6202948458916099932),
Felt::new(17690140365333231091),
Felt::new(3595001575307484651),
Felt::new(373995945117666487),
Felt::new(1235734395091296013),
Felt::new(14172757457833931602),
Felt::new(707573103686350224),
Felt::new(15453217512188187135),
Felt::new(219777875004506018),
Felt::new(17876696346199469008),
Felt::new(17731621626449383378),
Felt::new(2897136237748376248),
],
[
Felt::new(8023374565629191455),
Felt::new(15013690343205953430),
Felt::new(4485500052507912973),
Felt::new(12489737547229155153),
Felt::new(9500452585969030576),
Felt::new(2054001340201038870),
Felt::new(12420704059284934186),
Felt::new(355990932618543755),
Felt::new(9071225051243523860),
Felt::new(12766199826003448536),
Felt::new(9045979173463556963),
Felt::new(12934431667190679898),
],
[
Felt::new(18389244934624494276),
Felt::new(16731736864863925227),
Felt::new(4440209734760478192),
Felt::new(17208448209698888938),
Felt::new(8739495587021565984),
Felt::new(17000774922218161967),
Felt::new(13533282547195532087),
Felt::new(525402848358706231),
Felt::new(16987541523062161972),
Felt::new(5466806524462797102),
Felt::new(14512769585918244983),
Felt::new(10973956031244051118),
],
[
Felt::new(6982293561042362913),
Felt::new(14065426295947720331),
Felt::new(16451845770444974180),
Felt::new(7139138592091306727),
Felt::new(9012006439959783127),
Felt::new(14619614108529063361),
Felt::new(1394813199588124371),
Felt::new(4635111139507788575),
Felt::new(16217473952264203365),
Felt::new(10782018226466330683),
Felt::new(6844229992533662050),
Felt::new(7446486531695178711),
],
[
Felt::new(3736792340494631448),
Felt::new(577852220195055341),
Felt::new(6689998335515779805),
Felt::new(13886063479078013492),
Felt::new(14358505101923202168),
Felt::new(7744142531772274164),
Felt::new(16135070735728404443),
Felt::new(12290902521256031137),
Felt::new(12059913662657709804),
Felt::new(16456018495793751911),
Felt::new(4571485474751953524),
Felt::new(17200392109565783176),
],
[
Felt::new(17130398059294018733),
Felt::new(519782857322261988),
Felt::new(9625384390925085478),
Felt::new(1664893052631119222),
Felt::new(7629576092524553570),
Felt::new(3485239601103661425),
Felt::new(9755891797164033838),
Felt::new(15218148195153269027),
Felt::new(16460604813734957368),
Felt::new(9643968136937729763),
Felt::new(3611348709641382851),
Felt::new(18256379591337759196),
],
];

View File

@@ -1,9 +1,11 @@
-#![no_std]
-#[macro_use]
+#![cfg_attr(not(feature = "std"), no_std)]
+#[cfg(not(feature = "std"))]
+#[cfg_attr(test, macro_use)]
 extern crate alloc;
+#[cfg(feature = "std")]
+extern crate std;
pub mod dsa;
pub mod hash;
pub mod merkle;
@@ -13,7 +15,10 @@ pub mod utils;
// RE-EXPORTS
// ================================================================================================
-pub use winter_math::{fields::f64::BaseElement as Felt, FieldElement, StarkField};
+pub use winter_math::{
+    fields::{f64::BaseElement as Felt, CubeExtension, QuadExtension},
+    FieldElement, StarkField,
+};
// TYPE ALIASES
// ================================================================================================

View File

@@ -1,20 +1,15 @@
-use clap::Parser;
-use miden_crypto::{
-    hash::rpo::RpoDigest,
-    merkle::MerkleError,
-    Felt, Word, ONE,
-    {hash::rpo::Rpo256, merkle::TieredSmt},
-};
-use rand_utils::rand_value;
 use std::time::Instant;
+use clap::Parser;
+use miden_crypto::{
+    hash::rpo::{Rpo256, RpoDigest},
+    merkle::{MerkleError, Smt},
+    Felt, Word, ONE,
+};
+use rand_utils::rand_value;
 #[derive(Parser, Debug)]
-#[clap(
-    name = "Benchmark",
-    about = "Tiered SMT benchmark",
-    version,
-    rename_all = "kebab-case"
-)]
+#[clap(name = "Benchmark", about = "SMT benchmark", version, rename_all = "kebab-case")]
pub struct BenchmarkCmd {
/// Size of the tree
#[clap(short = 's', long = "size")]
@@ -22,11 +17,11 @@ pub struct BenchmarkCmd {
}
 fn main() {
-    benchmark_tsmt();
+    benchmark_smt();
 }
-/// Run a benchmark for the Tiered SMT.
-pub fn benchmark_tsmt() {
+/// Run a benchmark for [`Smt`].
+pub fn benchmark_smt() {
     let args = BenchmarkCmd::parse();
     let tree_size = args.size;
@@ -40,41 +35,29 @@ pub fn benchmark_tsmt() {
     let mut tree = construction(entries, tree_size).unwrap();
     insertion(&mut tree, tree_size).unwrap();
+    batched_insertion(&mut tree, tree_size).unwrap();
     proof_generation(&mut tree, tree_size).unwrap();
 }
-/// Runs the construction benchmark for the Tiered SMT, returning the constructed tree.
-pub fn construction(entries: Vec<(RpoDigest, Word)>, size: u64) -> Result<TieredSmt, MerkleError> {
+/// Runs the construction benchmark for [`Smt`], returning the constructed tree.
+pub fn construction(entries: Vec<(RpoDigest, Word)>, size: u64) -> Result<Smt, MerkleError> {
     println!("Running a construction benchmark:");
     let now = Instant::now();
-    let tree = TieredSmt::with_entries(entries)?;
+    let tree = Smt::with_entries(entries)?;
     let elapsed = now.elapsed();
     println!(
-        "Constructed a TSMT with {} key-value pairs in {:.3} seconds",
+        "Constructed a SMT with {} key-value pairs in {:.3} seconds",
         size,
         elapsed.as_secs_f32(),
     );
-    // Count how many nodes end up at each tier
-    let mut nodes_num_16_32_48 = (0, 0, 0);
-    tree.upper_leaf_nodes().for_each(|(index, _)| match index.depth() {
-        16 => nodes_num_16_32_48.0 += 1,
-        32 => nodes_num_16_32_48.1 += 1,
-        48 => nodes_num_16_32_48.2 += 1,
-        _ => unreachable!(),
-    });
-    println!("Number of nodes on depth 16: {}", nodes_num_16_32_48.0);
-    println!("Number of nodes on depth 32: {}", nodes_num_16_32_48.1);
-    println!("Number of nodes on depth 48: {}", nodes_num_16_32_48.2);
-    println!("Number of nodes on depth 64: {}\n", tree.bottom_leaves().count());
+    println!("Number of leaf nodes: {}\n", tree.leaves().count());
     Ok(tree)
 }
-/// Runs the insertion benchmark for the Tiered SMT.
-pub fn insertion(tree: &mut TieredSmt, size: u64) -> Result<(), MerkleError> {
+/// Runs the insertion benchmark for the [`Smt`].
+pub fn insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
     println!("Running an insertion benchmark:");
     let mut insertion_times = Vec::new();
@@ -90,9 +73,9 @@ pub fn insertion(tree: &mut TieredSmt, size: u64) -> Result<(), MerkleError> {
 }
 println!(
-    "An average insertion time measured by 20 inserts into a TSMT with {} key-value pairs is {:.3} milliseconds\n",
+    "An average insertion time measured by 20 inserts into a SMT with {} key-value pairs is {:.3} milliseconds\n",
     size,
     // calculate the average by dividing by 20 and convert to milliseconds by multiplying by
     // 1000. As a result, we can only multiply by 50
     insertion_times.iter().sum::<f32>() * 50f32,
 );
@@ -100,8 +83,56 @@ pub fn insertion(tree: &mut TieredSmt, size: u64) -> Result<(), MerkleError> {
     Ok(())
 }
-/// Runs the proof generation benchmark for the Tiered SMT.
-pub fn proof_generation(tree: &mut TieredSmt, size: u64) -> Result<(), MerkleError> {
+pub fn batched_insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
+    println!("Running a batched insertion benchmark:");
+    let new_pairs: Vec<(RpoDigest, Word)> = (0..1000)
+        .map(|i| {
+            let key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
+            let value = [ONE, ONE, ONE, Felt::new(size + i)];
+            (key, value)
+        })
+        .collect();
+    let now = Instant::now();
+    let mutations = tree.compute_mutations(new_pairs);
+    let compute_elapsed = now.elapsed();
+    let now = Instant::now();
+    tree.apply_mutations(mutations).unwrap();
+    let apply_elapsed = now.elapsed();
+    println!(
+        "An average batch computation time measured by a 1k-batch into an SMT with {} key-value pairs over {:.3} milliseconds is {:.3} milliseconds",
+        size,
+        compute_elapsed.as_secs_f32() * 1000f32,
+        // Dividing by the number of iterations, 1000, and then multiplying by 1000 to get
+        // milliseconds, cancels out.
+        compute_elapsed.as_secs_f32(),
+    );
+    println!(
+        "An average batch application time measured by a 1k-batch into an SMT with {} key-value pairs over {:.3} milliseconds is {:.3} milliseconds",
+        size,
+        apply_elapsed.as_secs_f32() * 1000f32,
+        // Dividing by the number of iterations, 1000, and then multiplying by 1000 to get
+        // milliseconds, cancels out.
+        apply_elapsed.as_secs_f32(),
+    );
+    println!(
+        "An average batch insertion time measured by a 1k-batch into an SMT with {} key-value pairs totals to {:.3} milliseconds",
+        size,
+        (compute_elapsed + apply_elapsed).as_secs_f32() * 1000f32,
+    );
+    println!();
+    Ok(())
+}
+/// Runs the proof generation benchmark for the [`Smt`].
+pub fn proof_generation(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
     println!("Running a proof generation benchmark:");
     let mut insertion_times = Vec::new();
@@ -112,13 +143,13 @@ pub fn proof_generation(tree: &mut TieredSmt, size: u64) -> Result<(), MerkleErr
     tree.insert(test_key, test_value);
     let now = Instant::now();
-    let _proof = tree.prove(test_key);
+    let _proof = tree.open(&test_key);
     let elapsed = now.elapsed();
     insertion_times.push(elapsed.as_secs_f32());
 }
 println!(
-    "An average proving time measured by 20 value proofs in a TSMT with {} key-value pairs in {:.3} microseconds",
+    "An average proving time measured by 20 value proofs in a SMT with {} key-value pairs in {:.3} microseconds",
     size,
     // calculate the average by dividing by 20 and convert to microseconds by multiplying by
     // 1000000. As a result, we can only multiply by 50000

View File

@@ -1,156 +0,0 @@
use super::{
BTreeMap, KvMap, MerkleError, MerkleStore, NodeIndex, RpoDigest, StoreNode, Vec, Word,
};
use crate::utils::collections::Diff;
#[cfg(test)]
use super::{super::ONE, Felt, SimpleSmt, EMPTY_WORD, ZERO};
// MERKLE STORE DELTA
// ================================================================================================
/// [MerkleStoreDelta] stores a vector of ([RpoDigest], [MerkleTreeDelta]) tuples where the
/// [RpoDigest] represents the root of the Merkle tree and [MerkleTreeDelta] represents the
/// differences between the initial and final Merkle tree states.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MerkleStoreDelta(pub Vec<(RpoDigest, MerkleTreeDelta)>);
// MERKLE TREE DELTA
// ================================================================================================
/// [MerkleDelta] stores the differences between the initial and final Merkle tree states.
///
/// The differences are represented as follows:
/// - depth: the depth of the merkle tree.
/// - cleared_slots: indexes of slots where values were set to [ZERO; 4].
/// - updated_slots: index-value pairs of slots where values were set to non [ZERO; 4] values.
#[cfg(not(test))]
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MerkleTreeDelta {
depth: u8,
cleared_slots: Vec<u64>,
updated_slots: Vec<(u64, Word)>,
}
impl MerkleTreeDelta {
// CONSTRUCTOR
// --------------------------------------------------------------------------------------------
pub fn new(depth: u8) -> Self {
Self {
depth,
cleared_slots: Vec::new(),
updated_slots: Vec::new(),
}
}
// ACCESSORS
// --------------------------------------------------------------------------------------------
/// Returns the depth of the Merkle tree the [MerkleTreeDelta] is associated with.
pub fn depth(&self) -> u8 {
self.depth
}
/// Returns the indexes of slots where values were set to [ZERO; 4].
pub fn cleared_slots(&self) -> &[u64] {
&self.cleared_slots
}
/// Returns the index-value pairs of slots where values were set to non [ZERO; 4] values.
pub fn updated_slots(&self) -> &[(u64, Word)] {
&self.updated_slots
}
// MODIFIERS
// --------------------------------------------------------------------------------------------
/// Adds a slot index to the list of cleared slots.
pub fn add_cleared_slot(&mut self, index: u64) {
self.cleared_slots.push(index);
}
/// Adds a slot index and a value to the list of updated slots.
pub fn add_updated_slot(&mut self, index: u64, value: Word) {
self.updated_slots.push((index, value));
}
}
/// Extracts a [MerkleTreeDelta] object by comparing the leaves of two Merkle trees specified by
/// their roots and depth.
pub fn merkle_tree_delta<T: KvMap<RpoDigest, StoreNode>>(
tree_root_1: RpoDigest,
tree_root_2: RpoDigest,
depth: u8,
merkle_store: &MerkleStore<T>,
) -> Result<MerkleTreeDelta, MerkleError> {
if tree_root_1 == tree_root_2 {
return Ok(MerkleTreeDelta::new(depth));
}
let tree_1_leaves: BTreeMap<NodeIndex, RpoDigest> =
merkle_store.non_empty_leaves(tree_root_1, depth).collect();
let tree_2_leaves: BTreeMap<NodeIndex, RpoDigest> =
merkle_store.non_empty_leaves(tree_root_2, depth).collect();
let diff = tree_1_leaves.diff(&tree_2_leaves);
// TODO: Refactor this diff implementation to prevent allocation of both BTree and Vec.
Ok(MerkleTreeDelta {
depth,
cleared_slots: diff.removed.into_iter().map(|index| index.value()).collect(),
updated_slots: diff
.updated
.into_iter()
.map(|(index, leaf)| (index.value(), *leaf))
.collect(),
})
}
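The delta extraction above boils down to a map diff: keys present only in the first tree become cleared slots, while keys that are new or changed in the second tree become updated slots. A minimal standalone sketch of that split (plain `std` maps standing in for the leaf maps; `delta` is an illustrative helper, not part of this crate):

```rust
use std::collections::BTreeMap;

// Split the difference between two leaf maps into cleared slots (present in
// `old` but absent from `new`) and updated slots (added or changed in `new`).
fn delta(old: &BTreeMap<u64, u32>, new: &BTreeMap<u64, u32>) -> (Vec<u64>, Vec<(u64, u32)>) {
    let cleared = old.keys().filter(|k| !new.contains_key(*k)).copied().collect();
    let updated = new
        .iter()
        .filter(|(k, v)| old.get(*k) != Some(*v))
        .map(|(k, v)| (*k, *v))
        .collect();
    (cleared, updated)
}

fn main() {
    let old = BTreeMap::from([(10u64, 1u32), (15, 2)]);
    let new = BTreeMap::from([(10u64, 9u32), (31, 3)]);
    let (cleared, updated) = delta(&old, &new);
    assert_eq!(cleared, vec![15]);
    assert_eq!(updated, vec![(10, 9), (31, 3)]);
    println!("cleared={cleared:?} updated={updated:?}");
}
```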
// INTERNALS
// --------------------------------------------------------------------------------------------
#[cfg(test)]
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MerkleTreeDelta {
pub depth: u8,
pub cleared_slots: Vec<u64>,
pub updated_slots: Vec<(u64, Word)>,
}
// MERKLE DELTA
// ================================================================================================
#[test]
fn test_compute_merkle_delta() {
let entries = vec![
(10, [ZERO, ONE, Felt::new(2), Felt::new(3)]),
(15, [Felt::new(4), Felt::new(5), Felt::new(6), Felt::new(7)]),
(20, [Felt::new(8), Felt::new(9), Felt::new(10), Felt::new(11)]),
(31, [Felt::new(12), Felt::new(13), Felt::new(14), Felt::new(15)]),
];
let simple_smt = SimpleSmt::with_leaves(30, entries.clone()).unwrap();
let mut store: MerkleStore = (&simple_smt).into();
let root = simple_smt.root();
// add a new node
let new_value = [Felt::new(16), Felt::new(17), Felt::new(18), Felt::new(19)];
let new_index = NodeIndex::new(simple_smt.depth(), 32).unwrap();
let root = store.set_node(root, new_index, new_value.into()).unwrap().root;
// update an existing node
let update_value = [Felt::new(20), Felt::new(21), Felt::new(22), Felt::new(23)];
let update_idx = NodeIndex::new(simple_smt.depth(), entries[0].0).unwrap();
let root = store.set_node(root, update_idx, update_value.into()).unwrap().root;
// remove a node
let remove_idx = NodeIndex::new(simple_smt.depth(), entries[1].0).unwrap();
let root = store.set_node(root, remove_idx, EMPTY_WORD.into()).unwrap().root;
let merkle_delta =
merkle_tree_delta(simple_smt.root(), root, simple_smt.depth(), &store).unwrap();
let expected_merkle_delta = MerkleTreeDelta {
depth: simple_smt.depth(),
cleared_slots: vec![remove_idx.value()],
updated_slots: vec![(update_idx.value(), update_value), (new_index.value(), new_value)],
};
assert_eq!(merkle_delta, expected_merkle_delta);
}
@@ -1,6 +1,7 @@
-use super::{Felt, RpoDigest, EMPTY_WORD};
use core::slice;
+use super::{smt::InnerNode, Felt, RpoDigest, EMPTY_WORD};
// EMPTY NODES SUBTREES
// ================================================================================================
@@ -10,12 +11,30 @@ pub struct EmptySubtreeRoots;
impl EmptySubtreeRoots {
/// Returns a static slice with roots of empty subtrees of a Merkle tree starting at the
/// specified depth.
-pub const fn empty_hashes(depth: u8) -> &'static [RpoDigest] {
-let ptr = &EMPTY_SUBTREES[255 - depth as usize] as *const RpoDigest;
+pub const fn empty_hashes(tree_depth: u8) -> &'static [RpoDigest] {
+let ptr = &EMPTY_SUBTREES[255 - tree_depth as usize] as *const RpoDigest;
// Safety: this is a static/constant array, so it will never be outlived. If we attempt to
// use regular slices, this wouldn't be a `const` function, meaning we won't be able to use
// the returned value for static/constant definitions.
-unsafe { slice::from_raw_parts(ptr, depth as usize + 1) }
+unsafe { slice::from_raw_parts(ptr, tree_depth as usize + 1) }
}
/// Returns the node's digest for a sub-tree with all its leaves set to the empty word.
pub const fn entry(tree_depth: u8, node_depth: u8) -> &'static RpoDigest {
assert!(node_depth <= tree_depth);
let pos = 255 - tree_depth + node_depth;
&EMPTY_SUBTREES[pos as usize]
}
/// Returns a sparse Merkle tree [`InnerNode`] with two empty children.
///
/// # Note
/// `node_depth` is the depth of the **parent** to have empty children. That is, `node_depth`
/// and the depth of the returned [`InnerNode`] are the same, and thus the empty hashes are for
/// subtrees of depth `node_depth + 1`.
pub(crate) const fn get_inner_node(tree_depth: u8, node_depth: u8) -> InnerNode {
let &child = Self::entry(tree_depth, node_depth + 1);
InnerNode { left: child, right: child }
}
}
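The `EMPTY_SUBTREES` table behind these accessors follows a simple recurrence: the empty root at some depth is the hash of two identical empty roots one level deeper, with the empty leaf at the bottom. A sketch of that shape (using `std`'s `DefaultHasher` as a stand-in 2-to-1 merge, NOT `Rpo256`; `merge` and `empty_roots` are illustrative names):

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Stand-in 2-to-1 merge (NOT Rpo256); only the recurrence shape matters here.
fn merge(left: u64, right: u64) -> u64 {
    let mut h = DefaultHasher::new();
    left.hash(&mut h);
    right.hash(&mut h);
    h.finish()
}

// roots[node_depth] = digest of an empty subtree rooted at `node_depth` in a
// tree of depth `tree_depth`; roots[tree_depth] is the empty leaf itself.
fn empty_roots(tree_depth: usize, empty_leaf: u64) -> Vec<u64> {
    let mut roots = vec![empty_leaf; tree_depth + 1];
    for d in (0..tree_depth).rev() {
        roots[d] = merge(roots[d + 1], roots[d + 1]);
    }
    roots
}

fn main() {
    let roots = empty_roots(8, 0);
    assert_eq!(roots.len(), 9);
    // every inner node has two identical empty children, as in get_inner_node()
    for d in 0..8 {
        assert_eq!(roots[d], merge(roots[d + 1], roots[d + 1]));
    }
}
```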
@@ -1583,3 +1602,16 @@ fn all_depths_opens_to_zero() {
.for_each(|(x, computed)| assert_eq!(x, computed));
}
}
#[test]
fn test_entry() {
// check the leaf is always the empty word
for depth in 0..255 {
assert_eq!(EmptySubtreeRoots::entry(depth, depth), &RpoDigest::new(EMPTY_WORD));
}
// check the root matches the first element of empty_hashes
for depth in 0..255 {
assert_eq!(EmptySubtreeRoots::entry(depth, 0), &EmptySubtreeRoots::empty_hashes(depth)[0]);
}
}
@@ -1,54 +1,34 @@
-use crate::{
-merkle::{MerklePath, NodeIndex, RpoDigest},
-utils::collections::Vec,
-};
-use core::fmt;
+use thiserror::Error;
-#[derive(Clone, Debug, PartialEq, Eq)]
+use super::{NodeIndex, RpoDigest};
+#[derive(Debug, Error)]
pub enum MerkleError {
-ConflictingRoots(Vec<RpoDigest>),
+#[error("expected merkle root {expected_root} found {actual_root}")]
+ConflictingRoots {
+expected_root: RpoDigest,
+actual_root: RpoDigest,
+},
+#[error("provided merkle tree depth {0} is too small")]
DepthTooSmall(u8),
+#[error("provided merkle tree depth {0} is too big")]
DepthTooBig(u64),
+#[error("multiple values provided for merkle tree index {0}")]
DuplicateValuesForIndex(u64),
-DuplicateValuesForKey(RpoDigest),
-InvalidIndex { depth: u8, value: u64 },
-InvalidDepth { expected: u8, provided: u8 },
-InvalidPath(MerklePath),
-InvalidNumEntries(usize, usize),
-NodeNotInSet(NodeIndex),
-NodeNotInStore(RpoDigest, NodeIndex),
+#[error("node index value {value} is not valid for depth {depth}")]
+InvalidNodeIndex { depth: u8, value: u64 },
+#[error("provided node index depth {provided} does not match expected depth {expected}")]
+InvalidNodeIndexDepth { expected: u8, provided: u8 },
+#[error("merkle subtree depth {subtree_depth} exceeds merkle tree depth {tree_depth}")]
+SubtreeDepthExceedsDepth { subtree_depth: u8, tree_depth: u8 },
+#[error("number of entries in the merkle tree exceeds the maximum of {0}")]
+TooManyEntries(usize),
+#[error("node index `{0}` not found in the tree")]
+NodeIndexNotFoundInTree(NodeIndex),
+#[error("node {0:?} with index `{1}` not found in the store")]
+NodeIndexNotFoundInStore(RpoDigest, NodeIndex),
+#[error("number of provided merkle tree leaves {0} is not a power of two")]
+NumLeavesNotPowerOfTwo(usize),
+#[error("root {0:?} is not in the store")]
+RootNotInStore(RpoDigest),
}
-impl fmt::Display for MerkleError {
-fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-use MerkleError::*;
-match self {
-ConflictingRoots(roots) => write!(f, "the merkle paths roots do not match {roots:?}"),
-DepthTooSmall(depth) => write!(f, "the provided depth {depth} is too small"),
-DepthTooBig(depth) => write!(f, "the provided depth {depth} is too big"),
-DuplicateValuesForIndex(key) => write!(f, "multiple values provided for key {key}"),
-DuplicateValuesForKey(key) => write!(f, "multiple values provided for key {key}"),
-InvalidIndex{ depth, value} => write!(
-f,
-"the index value {value} is not valid for the depth {depth}"
-),
-InvalidDepth { expected, provided } => write!(
-f,
-"the provided depth {provided} is not valid for {expected}"
-),
-InvalidPath(_path) => write!(f, "the provided path is not valid"),
-InvalidNumEntries(max, provided) => write!(f, "the provided number of entries is {provided}, but the maximum for the given depth is {max}"),
-NodeNotInSet(index) => write!(f, "the node with index ({index}) is not in the set"),
-NodeNotInStore(hash, index) => write!(f, "the node {hash:?} with index ({index}) is not in the store"),
-NumLeavesNotPowerOfTwo(leaves) => {
-write!(f, "the leaves count {leaves} is not a power of 2")
-}
-RootNotInStore(root) => write!(f, "the root {:?} is not in the store", root),
-}
-}
-}
-#[cfg(feature = "std")]
-impl std::error::Error for MerkleError {}
@@ -1,7 +1,8 @@
-use super::{Felt, MerkleError, RpoDigest, StarkField};
-use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
-use core::fmt::Display;
+use super::{Felt, MerkleError, RpoDigest};
+use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
// NODE INDEX
// ================================================================================================
@@ -37,7 +38,7 @@ impl NodeIndex {
/// Returns an error if the `value` is greater than or equal to 2^{depth}.
pub const fn new(depth: u8, value: u64) -> Result<Self, MerkleError> {
if (64 - value.leading_zeros()) > depth as u32 {
-Err(MerkleError::InvalidIndex { depth, value })
+Err(MerkleError::InvalidNodeIndex { depth, value })
} else {
Ok(Self { depth, value })
}
@@ -96,6 +97,14 @@ impl NodeIndex {
self
}
/// Returns the parent of the current node. This is the same as [`Self::move_up()`], but returns
/// a new value instead of mutating `self`.
pub const fn parent(mut self) -> Self {
self.depth = self.depth.saturating_sub(1);
self.value >>= 1;
self
}
// PROVIDERS
// --------------------------------------------------------------------------------------------
@@ -181,19 +190,28 @@ impl Deserializable for NodeIndex {
#[cfg(test)]
mod tests {
-use super::*;
+use assert_matches::assert_matches;
use proptest::prelude::*;
+use super::*;
#[test]
fn test_node_index_value_too_high() {
assert_eq!(NodeIndex::new(0, 0).unwrap(), NodeIndex { depth: 0, value: 0 });
-match NodeIndex::new(0, 1) {
-Err(MerkleError::InvalidIndex { depth, value }) => {
-assert_eq!(depth, 0);
-assert_eq!(value, 1);
-}
-_ => unreachable!(),
-}
+let err = NodeIndex::new(0, 1).unwrap_err();
+assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 0, value: 1 });
assert_eq!(NodeIndex::new(1, 1).unwrap(), NodeIndex { depth: 1, value: 1 });
let err = NodeIndex::new(1, 2).unwrap_err();
assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 1, value: 2 });
assert_eq!(NodeIndex::new(2, 3).unwrap(), NodeIndex { depth: 2, value: 3 });
let err = NodeIndex::new(2, 4).unwrap_err();
assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 2, value: 4 });
assert_eq!(NodeIndex::new(3, 7).unwrap(), NodeIndex { depth: 3, value: 7 });
let err = NodeIndex::new(3, 8).unwrap_err();
assert_matches!(err, MerkleError::InvalidNodeIndex { depth: 3, value: 8 });
}
#[test]
@@ -1,7 +1,8 @@
-use super::{InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Vec, Word};
-use crate::utils::{string::String, uninit_vector, word_to_hex};
+use alloc::{string::String, vec::Vec};
use core::{fmt, ops::Deref, slice};
-use winter_math::log2;
+use super::{InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Word};
+use crate::utils::{uninit_vector, word_to_hex};
// MERKLE TREE
// ================================================================================================
@@ -20,7 +21,11 @@ impl MerkleTree {
///
/// # Errors
/// Returns an error if the number of leaves is smaller than two or is not a power of two.
-pub fn new(leaves: Vec<Word>) -> Result<Self, MerkleError> {
+pub fn new<T>(leaves: T) -> Result<Self, MerkleError>
+where
+T: AsRef<[Word]>,
+{
+let leaves = leaves.as_ref();
let n = leaves.len();
if n <= 1 {
return Err(MerkleError::DepthTooSmall(n as u8));
@@ -34,7 +39,7 @@ impl MerkleTree {
// copy leaves into the second part of the nodes vector
nodes[n..].iter_mut().zip(leaves).for_each(|(node, leaf)| {
-*node = RpoDigest::from(leaf);
+*node = RpoDigest::from(*leaf);
});
// re-interpret nodes as an array of two nodes fused together
@@ -63,7 +68,7 @@ impl MerkleTree {
///
/// Merkle tree of depth 1 has two leaves, depth 2 has four leaves etc.
pub fn depth(&self) -> u8 {
-log2(self.nodes.len() / 2) as u8
+(self.nodes.len() / 2).ilog2() as u8
}
/// Returns a node at the specified depth and index value.
@@ -175,6 +180,26 @@ impl MerkleTree {
}
}
// CONVERSIONS
// ================================================================================================
impl TryFrom<&[Word]> for MerkleTree {
type Error = MerkleError;
fn try_from(value: &[Word]) -> Result<Self, Self::Error> {
MerkleTree::new(value)
}
}
impl TryFrom<&[RpoDigest]> for MerkleTree {
type Error = MerkleError;
fn try_from(value: &[RpoDigest]) -> Result<Self, Self::Error> {
let value: Vec<Word> = value.iter().map(|v| *v.deref()).collect();
MerkleTree::new(value)
}
}
// ITERATORS
// ================================================================================================
@@ -186,7 +211,7 @@ pub struct InnerNodeIterator<'a> {
index: usize,
}
-impl<'a> Iterator for InnerNodeIterator<'a> {
+impl Iterator for InnerNodeIterator<'_> {
type Item = InnerNodeInfo;
fn next(&mut self) -> Option<Self::Item> {
@@ -259,13 +284,15 @@ pub fn path_to_text(path: &MerklePath) -> Result<String, fmt::Error> {
#[cfg(test)]
mod tests {
+use core::mem::size_of;
+use proptest::prelude::*;
use super::*;
use crate::{
-merkle::{digests_to_words, int_to_leaf, int_to_node, InnerNodeInfo},
-Felt, Word, WORD_SIZE,
+merkle::{digests_to_words, int_to_leaf, int_to_node},
+Felt, WORD_SIZE,
};
-use core::mem::size_of;
-use proptest::prelude::*;
const LEAVES4: [RpoDigest; WORD_SIZE] =
[int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4)];
@@ -1,91 +0,0 @@
use super::{
super::{RpoDigest, Vec, ZERO},
Felt, MmrProof, Rpo256, Word,
};
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MmrPeaks {
/// The number of leaves is used to differentiate accumulators that have the same number of
/// peaks. This happens because the number of peaks goes up-and-down as the structure is used
/// causing existing trees to be merged and new ones to be created. As an example, every time
/// the MMR has a power-of-two number of leaves there is a single peak.
///
/// Every tree in the MMR forest has a distinct power-of-two size; this means only the
/// rightmost tree can have an odd number of elements (e.g. `1`). Additionally, this means
/// that the bits in `num_leaves` conveniently encode the size of each individual tree.
///
/// Examples:
///
/// - With 5 leaves, the binary is `0b101`. The number of set bits is equal to the number
/// of peaks; in this case there are 2 peaks. The 0-indexed least-significant position of
/// the bit determines the number of elements of a tree, so the rightmost tree has `2**0`
/// elements and the leftmost has `2**2`.
/// - With 12 leaves, the binary is `0b1100`; this case also has 2 peaks: the
/// leftmost tree has `2**3=8` elements, and the rightmost has `2**2=4` elements.
pub num_leaves: usize,
/// All the peaks of every tree in the MMR forest. The peaks are always ordered by number of
/// leaves, starting from the peak with most children, to the one with least.
///
/// Invariant: The length of `peaks` must be equal to the number of true bits in `num_leaves`.
pub peaks: Vec<RpoDigest>,
}
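The bit-encoding described above can be sketched in a few lines (illustrative helpers, not this crate's API; `peak_count` and `tree_sizes` are assumed names):

```rust
// Each set bit of `num_leaves` is one tree in the forest, so the number of
// peaks is the population count.
fn peak_count(num_leaves: usize) -> usize {
    num_leaves.count_ones() as usize
}

// Tree sizes, largest first, matching the documented ordering of `peaks`.
fn tree_sizes(num_leaves: usize) -> Vec<usize> {
    (0..usize::BITS)
        .rev()
        .filter(|b| num_leaves & (1usize << *b) != 0)
        .map(|b| 1usize << b)
        .collect()
}

fn main() {
    // 5 leaves = 0b101: two peaks, trees of 4 and 1 leaves
    assert_eq!(peak_count(0b101), 2);
    assert_eq!(tree_sizes(0b101), vec![4, 1]);
    // 12 leaves = 0b1100: two peaks, trees of 8 and 4 leaves
    assert_eq!(tree_sizes(0b1100), vec![8, 4]);
}
```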
impl MmrPeaks {
/// Hashes the peaks.
///
/// The procedure will:
/// - Flatten and pad the peaks to a vector of Felts.
/// - Hash the vector of Felts.
pub fn hash_peaks(&self) -> Word {
Rpo256::hash_elements(&self.flatten_and_pad_peaks()).into()
}
pub fn verify(&self, value: RpoDigest, opening: MmrProof) -> bool {
let root = &self.peaks[opening.peak_index()];
opening.merkle_path.verify(opening.relative_pos() as u64, value, root)
}
/// Flattens and pads the peaks to make hashing inside of the Miden VM easier.
///
/// The procedure will:
/// - Flatten the vector of Words into a vector of Felts.
/// - Pad the peaks with ZERO to an even number of words, this removes the need to handle RPO
/// padding.
/// - Pad the peaks to a minimum length of 16 words, which reduces the constant cost of
/// hashing.
pub fn flatten_and_pad_peaks(&self) -> Vec<Felt> {
let num_peaks = self.peaks.len();
// To achieve the padding rules above we calculate the length of the final vector.
// This is calculated as the number of field elements. Each peak is 4 field elements.
// The length is calculated as follows:
// - If there are less than 16 peaks, the data is padded to 16 peaks and as such requires
// 64 field elements.
// - If there are more than 16 peaks and the number of peaks is odd, the data is padded to
// an even number of peaks and as such requires `(num_peaks + 1) * 4` field elements.
// - If there are more than 16 peaks and the number of peaks is even, the data is not padded
// and as such requires `num_peaks * 4` field elements.
let len = if num_peaks < 16 {
64
} else if num_peaks % 2 == 1 {
(num_peaks + 1) * 4
} else {
num_peaks * 4
};
let mut elements = Vec::with_capacity(len);
elements.extend_from_slice(
&self
.peaks
.as_slice()
.iter()
.map(|digest| digest.into())
.collect::<Vec<Word>>()
.concat(),
);
elements.resize(len, ZERO);
elements
}
}
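The padding rules in `flatten_and_pad_peaks` reduce to a small arithmetic function. A sketch of just the length computation (the `padded_len` name is illustrative):

```rust
// The padded length in field elements implied by the rules above; each peak
// is 4 field elements.
fn padded_len(num_peaks: usize) -> usize {
    if num_peaks < 16 {
        64 // pad to 16 peaks
    } else if num_peaks % 2 == 1 {
        (num_peaks + 1) * 4 // pad to an even number of peaks
    } else {
        num_peaks * 4 // no padding needed
    }
}

fn main() {
    assert_eq!(padded_len(3), 64);
    assert_eq!(padded_len(16), 64);
    assert_eq!(padded_len(17), 72);
    assert_eq!(padded_len(18), 72);
}
```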
src/merkle/mmr/delta.rs (new file)
@@ -0,0 +1,18 @@
use alloc::vec::Vec;
use super::super::RpoDigest;
/// Container for the update data of a [super::PartialMmr]
#[derive(Debug)]
pub struct MmrDelta {
/// The new version of the [super::Mmr]
pub forest: usize,
/// Update data.
///
/// The data is packed as follows:
/// 1. All the elements needed to perform authentication path updates. These are the right
/// siblings required to perform tree merges on the [super::PartialMmr].
/// 2. The new peaks.
pub data: Vec<RpoDigest>,
}
src/merkle/mmr/error.rs (new file)
@@ -0,0 +1,27 @@
use alloc::string::String;
use thiserror::Error;
use crate::merkle::MerkleError;
#[derive(Debug, Error)]
pub enum MmrError {
#[error("mmr does not contain position {0}")]
PositionNotFound(usize),
#[error("mmr peaks are invalid: {0}")]
InvalidPeaks(String),
#[error(
"mmr peak does not match the computed merkle root of the provided authentication path"
)]
PeakPathMismatch,
#[error("requested peak index is {peak_idx} but the number of peaks is {peaks_len}")]
PeakOutOfBounds { peak_idx: usize, peaks_len: usize },
#[error("invalid mmr update")]
InvalidUpdate,
#[error("mmr does not contain a peak with depth {0}")]
UnknownPeak(u8),
#[error("invalid merkle path")]
InvalidMerklePath(#[source] MerkleError),
#[error("merkle root computation failed")]
MerkleRootComputationFailed(#[source] MerkleError),
}
@@ -7,18 +7,17 @@
//!
//! Additionally the structure only supports adding leaves to the right-most tree, the one with the
//! least number of leaves. The structure preserves the invariant that each tree has different
-//! depths, i.e. as part of adding adding a new element to the forest the trees with same depth are
+//! depths, i.e. as part of adding a new element to the forest the trees with same depth are
//! merged, creating a new tree with depth d+1, this process is continued until the property is
-//! restabilished.
-use super::{
-super::{InnerNodeInfo, MerklePath, RpoDigest, Vec},
-bit::TrueBitPositionIterator,
-MmrPeaks, MmrProof, Rpo256,
-};
-use core::fmt::{Display, Formatter};
+//! reestablished.
+use alloc::vec::Vec;
-#[cfg(feature = "std")]
-use std::error::Error;
+use super::{
+super::{InnerNodeInfo, MerklePath},
+bit::TrueBitPositionIterator,
+leaf_to_corresponding_tree, nodes_in_forest, MmrDelta, MmrError, MmrPeaks, MmrProof, Rpo256,
+RpoDigest,
+};
// MMR
// ===============================================================================================
@@ -43,22 +42,6 @@ pub struct Mmr {
pub(super) nodes: Vec<RpoDigest>,
}
-#[derive(Debug, PartialEq, Eq, Copy, Clone)]
-pub enum MmrError {
-InvalidPosition(usize),
-}
-impl Display for MmrError {
-fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), core::fmt::Error> {
-match self {
-MmrError::InvalidPosition(pos) => write!(fmt, "Mmr does not contain position {pos}"),
-}
-}
-}
-#[cfg(feature = "std")]
-impl Error for MmrError {}
impl Default for Mmr {
fn default() -> Self {
Self::new()
@@ -90,34 +73,46 @@ impl Mmr {
// FUNCTIONALITY
// ============================================================================================
-/// Given a leaf position, returns the Merkle path to its corresponding peak. If the position
-/// is greater-or-equal than the tree size an error is returned.
+/// Returns an [MmrProof] for the leaf at the specified position.
///
/// Note: The leaf position is the 0-indexed number corresponding to the order the leaves were
/// added, this corresponds to the MMR size _prior_ to adding the element. So the 1st element
/// has position 0, the second position 1, and so on.
///
/// # Errors
/// Returns an error if the specified leaf position is out of bounds for this MMR.
pub fn open(&self, pos: usize) -> Result<MmrProof, MmrError> {
self.open_at(pos, self.forest)
}
/// Returns an [MmrProof] for the leaf at the specified position using the state of the MMR
/// at the specified `forest`.
///
/// Note: The leaf position is the 0-indexed number corresponding to the order the leaves were
/// added, this corresponds to the MMR size _prior_ to adding the element. So the 1st element
/// has position 0, the second position 1, and so on.
///
/// # Errors
/// Returns an error if:
/// - The specified leaf position is out of bounds for this MMR.
/// - The specified `forest` value is not valid for this MMR.
pub fn open_at(&self, pos: usize, forest: usize) -> Result<MmrProof, MmrError> {
// find the target tree responsible for the MMR position
let tree_bit =
-leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::InvalidPosition(pos))?;
-let forest_target = 1usize << tree_bit;
+leaf_to_corresponding_tree(pos, forest).ok_or(MmrError::PositionNotFound(pos))?;
// isolate the trees before the target
-let forest_before = self.forest & high_bitmask(tree_bit + 1);
+let forest_before = forest & high_bitmask(tree_bit + 1);
let index_offset = nodes_in_forest(forest_before);
-// find the root
-let index = nodes_in_forest(forest_target) - 1;
// update the value position from global to the target tree
let relative_pos = pos - forest_before;
// collect the path and the final index of the target value
-let (_, path) =
-self.collect_merkle_path_and_value(tree_bit, relative_pos, index_offset, index);
+let (_, path) = self.collect_merkle_path_and_value(tree_bit, relative_pos, index_offset);
Ok(MmrProof {
-forest: self.forest,
+forest,
position: pos,
merkle_path: MerklePath::new(path),
})
@@ -131,22 +126,17 @@ impl Mmr {
pub fn get(&self, pos: usize) -> Result<RpoDigest, MmrError> {
// find the target tree responsible for the MMR position
let tree_bit =
-leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::InvalidPosition(pos))?;
-let forest_target = 1usize << tree_bit;
+leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::PositionNotFound(pos))?;
// isolate the trees before the target
let forest_before = self.forest & high_bitmask(tree_bit + 1);
let index_offset = nodes_in_forest(forest_before);
-// find the root
-let index = nodes_in_forest(forest_target) - 1;
// update the value position from global to the target tree
let relative_pos = pos - forest_before;
// collect the path and the final index of the target value
-let (value, _) =
-self.collect_merkle_path_and_value(tree_bit, relative_pos, index_offset, index);
+let (value, _) = self.collect_merkle_path_and_value(tree_bit, relative_pos, index_offset);
Ok(value)
}
@@ -173,9 +163,24 @@ impl Mmr {
self.forest += 1;
}
-/// Returns an accumulator representing the current state of the MMR.
-pub fn accumulator(&self) -> MmrPeaks {
-let peaks: Vec<RpoDigest> = TrueBitPositionIterator::new(self.forest)
+/// Returns the current peaks of the MMR.
+pub fn peaks(&self) -> MmrPeaks {
+self.peaks_at(self.forest).expect("failed to get peaks at current forest")
+}
/// Returns the peaks of the MMR at the state specified by `forest`.
///
/// # Errors
/// Returns an error if the specified `forest` value is not valid for this MMR.
pub fn peaks_at(&self, forest: usize) -> Result<MmrPeaks, MmrError> {
if forest > self.forest {
return Err(MmrError::InvalidPeaks(format!(
"requested forest {forest} exceeds current forest {}",
self.forest
)));
}
let peaks: Vec<RpoDigest> = TrueBitPositionIterator::new(forest)
.rev()
.map(|bit| nodes_in_forest(1 << bit))
.scan(0, |offset, el| {
@@ -185,7 +190,84 @@ impl Mmr {
.map(|offset| self.nodes[offset - 1])
.collect();
-MmrPeaks { num_leaves: self.forest, peaks }
+// Safety: the invariant is maintained by the [Mmr]
+let peaks = MmrPeaks::new(forest, peaks).unwrap();
+Ok(peaks)
}
/// Compute the required update to `original_forest`.
///
/// The result is a packed sequence of the authentication elements required to update the trees
/// that have been merged together, followed by the new peaks of the [Mmr].
pub fn get_delta(&self, from_forest: usize, to_forest: usize) -> Result<MmrDelta, MmrError> {
if to_forest > self.forest || from_forest > to_forest {
return Err(MmrError::InvalidPeaks(format!("to_forest {to_forest} exceeds the current forest {} or from_forest {from_forest} exceeds to_forest", self.forest)));
}
if from_forest == to_forest {
return Ok(MmrDelta { forest: to_forest, data: Vec::new() });
}
let mut result = Vec::new();
// Find the largest tree in this [Mmr] which is new to `from_forest`.
let candidate_trees = to_forest ^ from_forest;
let mut new_high = 1 << candidate_trees.ilog2();
// Collect authentication nodes used for tree merges
// ----------------------------------------------------------------------------------------
// Find the trees from `from_forest` that have been merged into `new_high`.
let mut merges = from_forest & (new_high - 1);
// Find the peaks that are common to `from_forest` and this [Mmr]
let common_trees = from_forest ^ merges;
if merges != 0 {
// Skip the smallest trees unknown to `from_forest`.
let mut target = 1 << merges.trailing_zeros();
// Collect siblings required to compute the merged tree's peak
while target < new_high {
// Computes the offset to the smallest known peak
// - common_trees: peaks unchanged in the current update, target comes after these.
// - merges: peaks that have not been merged so far, target comes after these.
// - target: tree from which to load the sibling. On the first iteration this is a
// value known by the partial mmr, on subsequent iterations this value is to be
// computed from the known peaks and provided authentication nodes.
let known = nodes_in_forest(common_trees | merges | target);
let sibling = nodes_in_forest(target);
result.push(self.nodes[known + sibling - 1]);
// Update the target and account for tree merges
target <<= 1;
while merges & target != 0 {
target <<= 1;
}
// Remove the merges done so far
merges ^= merges & (target - 1);
}
} else {
// The new high tree may not be the result of any merges, if it is smaller than all the
// trees of `from_forest`.
new_high = 0;
}
// Collect the new [Mmr] peaks
// ----------------------------------------------------------------------------------------
let mut new_peaks = to_forest ^ common_trees ^ new_high;
let old_peaks = to_forest ^ new_peaks;
let mut offset = nodes_in_forest(old_peaks);
while new_peaks != 0 {
let target = 1 << new_peaks.ilog2();
offset += nodes_in_forest(target);
result.push(self.nodes[offset - 1]);
new_peaks ^= target;
}
Ok(MmrDelta { forest: to_forest, data: result })
}
/// An iterator over inner nodes in the MMR. The order of iteration is unspecified.
@@ -202,36 +284,52 @@ impl Mmr {
// ============================================================================================
/// Internal function used to collect the Merkle path of a value.
///
/// The arguments are relative to the target tree. To compute the opening of the second leaf
/// for a tree with depth 2 in the forest `0b110`:
///
/// - `tree_bit`: Depth of the target tree, e.g. 2 for the smallest tree.
/// - `relative_pos`: 0-indexed leaf position in the target tree, e.g. 1 for the second leaf.
/// - `index_offset`: Node count prior to the target tree, e.g. 7 for the tree of depth 3.
fn collect_merkle_path_and_value(
&self,
tree_bit: u32,
relative_pos: usize,
index_offset: usize,
-mut index: usize,
) -> (RpoDigest, Vec<RpoDigest>) {
-// collect the Merkle path
-let mut tree_depth = tree_bit as usize;
-let mut path = Vec::with_capacity(tree_depth + 1);
-while tree_depth > 0 {
-let bit = relative_pos & tree_depth;
-let right_offset = index - 1;
-let left_offset = right_offset - nodes_in_forest(tree_depth);
+// see documentation of `leaf_to_corresponding_tree` for details
+let tree_depth = (tree_bit + 1) as usize;
+let mut path = Vec::with_capacity(tree_depth);
-// Elements to the right have a higher position because they were
-// added later. Therefore when the bit is true the node's path is
-// to the right, and its sibling to the left.
-let sibling = if bit != 0 {
+// The tree walk below goes from the root to the leaf, compute the root index to start
+let mut forest_target = 1usize << tree_bit;
+let mut index = nodes_in_forest(forest_target) - 1;
+// Loop until the leaf is reached
+while forest_target > 1 {
+// Update the depth of the tree to correspond to a subtree
+forest_target >>= 1;
+// compute the indices of the right and left subtrees based on the post-order
+let right_offset = index - 1;
+let left_offset = right_offset - nodes_in_forest(forest_target);
+let left_or_right = relative_pos & forest_target;
+let sibling = if left_or_right != 0 {
+// going down the right subtree, the right child becomes the new root
index = right_offset;
+// and the left child is the authentication
self.nodes[index_offset + left_offset]
} else {
index = left_offset;
self.nodes[index_offset + right_offset]
};
-tree_depth >>= 1;
path.push(sibling);
}
+debug_assert!(path.len() == tree_depth - 1);
// the rest of the codebase has the elements going from leaf to root; adjust it here for
// ease of use/consistency's sake
path.reverse();
@@ -241,6 +339,9 @@ impl Mmr {
}
}
// CONVERSIONS
// ================================================================================================
impl<T> From<T> for Mmr
where
T: IntoIterator<Item = RpoDigest>,
@@ -272,7 +373,7 @@ pub struct MmrNodes<'a> {
index: usize,
}
-impl<'a> Iterator for MmrNodes<'a> {
+impl Iterator for MmrNodes<'_> {
type Item = InnerNodeInfo;
fn next(&mut self) -> Option<Self::Item> {
@@ -305,7 +406,8 @@ impl<'a> Iterator for MmrNodes<'a> {
// the next parent position is one above the position of the pair
let parent = self.last_right << 1;
// the left node has been paired and the current parent yielded, remove it from the
// forest
self.forest ^= self.last_right;
if self.forest & parent == 0 {
// this iteration yielded the left parent node
@@ -335,32 +437,6 @@ impl<'a> Iterator for MmrNodes<'a> {
// UTILITIES
// ===============================================================================================
/// Given a 0-indexed leaf position and the current forest, return the tree number responsible for
/// the position.
///
/// Note:
/// The result is a tree position `p`; it has the following interpretations: $p+1$ is the depth
/// of the tree, which corresponds to the size of a Merkle proof for that tree. $2^p$ is equal
/// to the number of leaves in this particular tree, and $2^{p+1}-1$ corresponds to the size of
/// the tree.
pub(crate) const fn leaf_to_corresponding_tree(pos: usize, forest: usize) -> Option<u32> {
if pos >= forest {
None
} else {
// - each bit in the forest is a unique tree and the bit position its power-of-two size
// - each tree owns a consecutive range of positions equal to its size from left-to-right
// - this means the first tree owns from `0` up to the `2^k_0` first positions, where `k_0`
//   is the highest true bit position, the second tree from `2^k_0 + 1` up to `2^k_1` where
//   `k_1` is the second highest bit, and so on.
// - this means the highest bits work as a category marker, and the position is owned by
//   the first tree which doesn't share a high bit with the position
let before = forest & pos;
let after = forest ^ before;
let tree = after.ilog2();
Some(tree)
}
}
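The bit trick above can be exercised in isolation. This is a standalone sketch (same body as the function, with a hypothetical `main` walking a two-tree forest):

```rust
/// Standalone copy of `leaf_to_corresponding_tree` for illustration.
fn leaf_to_corresponding_tree(pos: usize, forest: usize) -> Option<u32> {
    if pos >= forest {
        None
    } else {
        // high bits shared with the position select the trees that come before it
        let before = forest & pos;
        let after = forest ^ before;
        Some(after.ilog2())
    }
}

fn main() {
    // forest 0b110 = a 4-leaf tree (bit 2) and a 2-leaf tree (bit 1):
    // positions 0..=3 belong to the 4-leaf tree, positions 4..=5 to the
    // 2-leaf tree, and position 6 is out of range.
    assert_eq!(leaf_to_corresponding_tree(0, 0b110), Some(2));
    assert_eq!(leaf_to_corresponding_tree(3, 0b110), Some(2));
    assert_eq!(leaf_to_corresponding_tree(4, 0b110), Some(1));
    assert_eq!(leaf_to_corresponding_tree(6, 0b110), None);
}
```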
/// Return a bitmask for the bits including and above the given position.
pub(crate) const fn high_bitmask(bit: u32) -> usize {
if bit > usize::BITS - 1 {
@@ -369,17 +445,3 @@ pub(crate) const fn high_bitmask(bit: u32) -> usize {
usize::MAX << bit
}
}
/// Return the total number of nodes of a given forest
///
/// Panics:
///
/// This will panic if the forest has size greater than `usize::MAX / 2`
pub(crate) const fn nodes_in_forest(forest: usize) -> usize {
// - the size of a perfect binary tree is $2^{k+1}-1$ or $2*2^k-1$
// - the forest represents the sum of $2^k$ so a single multiplication is necessary
// - the number of `-1` is the same as the number of trees, which is the same as the number of
//   bits set
let tree_count = forest.count_ones() as usize;
forest * 2 - tree_count
}
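As a worked check of the formula, a sketch mirroring the function above:

```rust
/// Standalone copy of `nodes_in_forest` for illustration.
fn nodes_in_forest(forest: usize) -> usize {
    // each tree of 2^k leaves has 2*2^k - 1 nodes; summing over set bits
    // gives 2*forest minus one per tree
    let tree_count = forest.count_ones() as usize;
    forest * 2 - tree_count
}

fn main() {
    // a single tree with 8 leaves has 2*8 - 1 = 15 nodes
    assert_eq!(nodes_in_forest(0b1000), 15);
    // forest 0b1010 = trees of 8 and 2 leaves: 15 + 3 = 18 nodes
    assert_eq!(nodes_in_forest(0b1010), 18);
}
```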

src/merkle/mmr/inorder.rs (new file, 191 lines)
@@ -0,0 +1,191 @@
//! Index for nodes of a binary tree based on an in-order tree walk.
//!
//! In an in-order walk, the parent node's index splits its left and right subtrees: all the
//! left children have indexes lower than the parent, while all nodes of the right subtree have
//! higher indexes. This property makes it easy to compute changes to the index by adding or
//! subtracting the leaf count.
use core::num::NonZeroUsize;
use winter_utils::{Deserializable, Serializable};
// IN-ORDER INDEX
// ================================================================================================
/// Index of nodes in a perfectly balanced binary tree based on an in-order tree walk.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct InOrderIndex {
idx: usize,
}
impl InOrderIndex {
// CONSTRUCTORS
// --------------------------------------------------------------------------------------------
/// Returns a new [InOrderIndex] instantiated from the provided value.
pub fn new(idx: NonZeroUsize) -> InOrderIndex {
InOrderIndex { idx: idx.get() }
}
/// Return a new [InOrderIndex] instantiated from the specified leaf position.
///
/// # Panics:
/// If `leaf` is higher than or equal to `usize::MAX / 2`.
pub fn from_leaf_pos(leaf: usize) -> InOrderIndex {
// Convert the position from 0-indexed to 1-indexed, since the bit manipulation in this
// implementation only works with 1-indexed counting.
let pos = leaf + 1;
InOrderIndex { idx: pos * 2 - 1 }
}
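The mapping can be checked on its own: leaf `n` lands on the odd 1-indexed value `2n + 1`. A minimal sketch (not the library type, just the arithmetic):

```rust
fn from_leaf_pos(leaf: usize) -> usize {
    // 1-indexed position of the leaf; the in-order index of a leaf is
    // twice its position minus one, which is always odd
    let pos = leaf + 1;
    pos * 2 - 1
}

fn main() {
    assert_eq!(from_leaf_pos(0), 1);
    assert_eq!(from_leaf_pos(1), 3);
    assert_eq!(from_leaf_pos(5), 11);
}
```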
// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------
/// True if the index is pointing at a leaf.
///
/// Every odd number represents a leaf.
pub fn is_leaf(&self) -> bool {
self.idx & 1 == 1
}
/// Returns true if this node is a left child of its parent.
pub fn is_left_child(&self) -> bool {
self.parent().left_child() == *self
}
/// Returns the level of the index.
///
/// Starts at level zero for leaves and increases by one for each parent.
pub fn level(&self) -> u32 {
self.idx.trailing_zeros()
}
/// Returns the index of the left child.
///
/// # Panics:
/// If the index corresponds to a leaf.
pub fn left_child(&self) -> InOrderIndex {
// The left child is itself a parent, with an index that splits its left/right subtrees. To
// go from the parent index to its left child, it is only necessary to subtract the count
// of elements on the child's right subtree + 1.
let els = 1 << (self.level() - 1);
InOrderIndex { idx: self.idx - els }
}
/// Returns the index of the right child.
///
/// # Panics:
/// If the index corresponds to a leaf.
pub fn right_child(&self) -> InOrderIndex {
// To compute the index of the parent of the right subtree it is sufficient to add the size
// of its left subtree + 1.
let els = 1 << (self.level() - 1);
InOrderIndex { idx: self.idx + els }
}
/// Returns the index of the parent node.
pub fn parent(&self) -> InOrderIndex {
// If the current index corresponds to a node in a left tree, to go up a level it is
// required to add the number of nodes of the right sibling, analogously if the node is a
// right child, going up requires subtracting the number of nodes in its left subtree.
//
// Both of the above operations can be performed by bitwise manipulation. Below the mask
// sets the number of trailing zeros to be equal to the new level of the index, and the bit
// marks the parent.
let target = self.level() + 1;
let bit = 1 << target;
let mask = bit - 1;
let idx = self.idx ^ (self.idx & mask);
InOrderIndex { idx: idx | bit }
}
/// Returns the index of the sibling node.
pub fn sibling(&self) -> InOrderIndex {
let parent = self.parent();
if *self > parent {
parent.left_child()
} else {
parent.right_child()
}
}
/// Returns the inner value of this [InOrderIndex].
pub fn inner(&self) -> u64 {
self.idx as u64
}
}
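The level/parent bit tricks can be exercised on a 7-node tree (in-order indexes 1..=7, root at 4). This is a self-contained sketch of the same arithmetic, using a hypothetical `Idx` wrapper rather than the library type:

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Idx(usize);

impl Idx {
    // level is the number of trailing zeros: leaves (odd indexes) are level 0
    fn level(self) -> u32 {
        self.0.trailing_zeros()
    }
    // clear the low bits, then set the bit marking the next level up
    fn parent(self) -> Idx {
        let bit = 1usize << (self.level() + 1);
        let mask = bit - 1;
        Idx((self.0 ^ (self.0 & mask)) | bit)
    }
}

fn main() {
    // in-order walk 1 2 3 4 5 6 7: parents of 1 and 3 are 2, parents of
    // 5 and 7 are 6, and 2 and 6 meet at the root 4
    assert_eq!(Idx(1).parent(), Idx(2));
    assert_eq!(Idx(3).parent(), Idx(2));
    assert_eq!(Idx(2).parent(), Idx(4));
    assert_eq!(Idx(5).parent(), Idx(6));
    assert_eq!(Idx(6).parent(), Idx(4));
}
```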
impl Serializable for InOrderIndex {
fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
target.write_usize(self.idx);
}
}
impl Deserializable for InOrderIndex {
fn read_from<R: winter_utils::ByteReader>(
source: &mut R,
) -> Result<Self, winter_utils::DeserializationError> {
let idx = source.read_usize()?;
Ok(InOrderIndex { idx })
}
}
// CONVERSIONS FROM IN-ORDER INDEX
// ------------------------------------------------------------------------------------------------
impl From<InOrderIndex> for u64 {
fn from(index: InOrderIndex) -> Self {
index.idx as u64
}
}
// TESTS
// ================================================================================================
#[cfg(test)]
mod test {
use proptest::prelude::*;
use winter_utils::{Deserializable, Serializable};
use super::InOrderIndex;
proptest! {
#[test]
fn proptest_inorder_index_random(count in 1..1000usize) {
let left_pos = count * 2;
let right_pos = count * 2 + 1;
let left = InOrderIndex::from_leaf_pos(left_pos);
let right = InOrderIndex::from_leaf_pos(right_pos);
assert!(left.is_leaf());
assert!(right.is_leaf());
assert_eq!(left.parent(), right.parent());
assert_eq!(left.parent().right_child(), right);
assert_eq!(left, right.parent().left_child());
assert_eq!(left.sibling(), right);
assert_eq!(left, right.sibling());
}
}
#[test]
fn test_inorder_index_basic() {
let left = InOrderIndex::from_leaf_pos(0);
let right = InOrderIndex::from_leaf_pos(1);
assert!(left.is_leaf());
assert!(right.is_leaf());
assert_eq!(left.parent(), right.parent());
assert_eq!(left.parent().right_child(), right);
assert_eq!(left, right.parent().left_child());
assert_eq!(left.sibling(), right);
assert_eq!(left, right.sibling());
}
#[test]
fn test_inorder_index_serialization() {
let index = InOrderIndex::from_leaf_pos(5);
let bytes = index.to_bytes();
let index2 = InOrderIndex::read_from_bytes(&bytes).unwrap();
assert_eq!(index, index2);
}
}


@@ -1,15 +1,67 @@
mod accumulator;
mod bit;
mod delta;
mod error;
mod full;
mod inorder;
mod partial;
mod peaks;
mod proof;
#[cfg(test)]
mod tests;
use super::{Felt, Rpo256, Word};
// REEXPORTS
// ================================================================================================
pub use accumulator::MmrPeaks;
pub use delta::MmrDelta;
pub use error::MmrError;
pub use full::Mmr;
pub use inorder::InOrderIndex;
pub use partial::PartialMmr;
pub use peaks::MmrPeaks;
pub use proof::MmrProof;
use super::{Felt, Rpo256, RpoDigest, Word};
// UTILITIES
// ===============================================================================================
/// Given a 0-indexed leaf position and the current forest, return the tree number responsible for
/// the position.
///
/// Note:
/// The result is a tree position `p`; it has the following interpretations: $p+1$ is the depth
/// of the tree. Because the root element is not part of the proof, $p$ is the length of the
/// authentication path. $2^p$ is equal to the number of leaves in this particular tree, and
/// $2^{p+1}-1$ corresponds to the size of the tree.
const fn leaf_to_corresponding_tree(pos: usize, forest: usize) -> Option<u32> {
if pos >= forest {
None
} else {
// - each bit in the forest is a unique tree and the bit position its power-of-two size
// - each tree owns a consecutive range of positions equal to its size from left-to-right
// - this means the first tree owns from `0` up to the `2^k_0` first positions, where `k_0`
//   is the highest true bit position, the second tree from `2^k_0 + 1` up to `2^k_1` where
//   `k_1` is the second highest bit, and so on.
// - this means the highest bits work as a category marker, and the position is owned by the
//   first tree which doesn't share a high bit with the position
let before = forest & pos;
let after = forest ^ before;
let tree = after.ilog2();
Some(tree)
}
}
/// Return the total number of nodes of a given forest
///
/// Panics:
///
/// This will panic if the forest has size greater than `usize::MAX / 2`
const fn nodes_in_forest(forest: usize) -> usize {
// - the size of a perfect binary tree is $2^{k+1}-1$ or $2*2^k-1$
// - the forest represents the sum of $2^k$ so a single multiplication is necessary
// - the number of `-1` is the same as the number of trees, which is the same as the number of
//   bits set
let tree_count = forest.count_ones() as usize;
forest * 2 - tree_count
}

src/merkle/mmr/partial.rs (new file, 952 lines)

@@ -0,0 +1,952 @@
use alloc::{
collections::{BTreeMap, BTreeSet},
vec::Vec,
};
use winter_utils::{Deserializable, Serializable};
use super::{MmrDelta, MmrProof, Rpo256, RpoDigest};
use crate::merkle::{
mmr::{leaf_to_corresponding_tree, nodes_in_forest},
InOrderIndex, InnerNodeInfo, MerklePath, MmrError, MmrPeaks,
};
// TYPE ALIASES
// ================================================================================================
type NodeMap = BTreeMap<InOrderIndex, RpoDigest>;
// PARTIAL MERKLE MOUNTAIN RANGE
// ================================================================================================
/// Partially materialized Merkle Mountain Range (MMR), used to efficiently store and update the
/// authentication paths for a subset of the elements in a full MMR.
///
/// This structure stores only the authentication paths for the tracked values; the values
/// themselves are stored separately.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PartialMmr {
/// The version of the MMR.
///
/// This value serves the following purposes:
///
/// - The forest is a counter for the total number of elements in the MMR.
/// - Since the MMR is an append-only structure, every change to it causes a change to the
/// `forest`, so this value has a dual purpose as a version tag.
/// - The bits in the forest also correspond to the count and sizes of the perfect binary
///   trees that compose the MMR structure, which serve to compute indexes and perform
///   validation.
pub(crate) forest: usize,
/// The MMR peaks.
///
/// The peaks are used for two reasons:
///
/// 1. It authenticates the addition of an element to the [PartialMmr], ensuring only valid
/// elements are tracked.
/// 2. During an MMR update, peaks can be merged by hashing the left and right hand sides.
///    The peaks are used as the left hand side.
///
/// All the peaks of every tree in the MMR forest. The peaks are always ordered by number of
/// leaves, starting from the peak with most children, to the one with least.
pub(crate) peaks: Vec<RpoDigest>,
/// Authentication nodes used to construct merkle paths for a subset of the MMR's leaves.
///
/// This does not include the MMR's peaks nor the tracked nodes, only the elements required to
/// construct their authentication paths. This property is used to detect when elements can be
/// safely removed, because they are no longer required to authenticate any element in the
/// [PartialMmr].
///
/// The elements in the MMR are referenced using an in-order tree index. This indexing scheme
/// permits for easy computation of the relative nodes (left/right children, sibling, parent),
/// which is useful for traversal. The indexing is also stable, meaning that merges to the
/// trees in the MMR can be represented without rewrites of the indexes.
pub(crate) nodes: NodeMap,
/// Flag indicating if the odd element should be tracked.
///
/// This flag is necessary because the sibling of the odd element doesn't exist yet, so it
/// cannot be added into `nodes` to signal that the value is being tracked.
pub(crate) track_latest: bool,
}
impl PartialMmr {
// CONSTRUCTORS
// --------------------------------------------------------------------------------------------
/// Returns a new [PartialMmr] instantiated from the specified peaks.
pub fn from_peaks(peaks: MmrPeaks) -> Self {
let forest = peaks.num_leaves();
let peaks = peaks.into();
let nodes = BTreeMap::new();
let track_latest = false;
Self { forest, peaks, nodes, track_latest }
}
/// Returns a new [PartialMmr] instantiated from the specified components.
///
/// This constructor does not check the consistency between peaks and nodes. If the specified
/// peaks and nodes are inconsistent, the returned partial MMR may exhibit undefined behavior.
pub fn from_parts(peaks: MmrPeaks, nodes: NodeMap, track_latest: bool) -> Self {
let forest = peaks.num_leaves();
let peaks = peaks.into();
Self { forest, peaks, nodes, track_latest }
}
// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------
/// Returns the current `forest` of this [PartialMmr].
///
/// This value corresponds to the version of the [PartialMmr] and the number of leaves in the
/// underlying MMR.
pub fn forest(&self) -> usize {
self.forest
}
/// Returns the number of leaves in the underlying MMR for this [PartialMmr].
pub fn num_leaves(&self) -> usize {
self.forest
}
/// Returns the peaks of the MMR for this [PartialMmr].
pub fn peaks(&self) -> MmrPeaks {
// expect() is OK here because the constructor ensures that MMR peaks can be constructed
// correctly
MmrPeaks::new(self.forest, self.peaks.clone()).expect("invalid MMR peaks")
}
/// Returns true if this partial MMR tracks an authentication path for the leaf at the
/// specified position.
pub fn is_tracked(&self, pos: usize) -> bool {
if pos >= self.forest {
return false;
} else if pos == self.forest - 1 && self.forest & 1 != 0 {
// if the number of leaves in the MMR is odd and the position is for the last leaf
// whether the leaf is tracked is defined by the `track_latest` flag
return self.track_latest;
}
let leaf_index = InOrderIndex::from_leaf_pos(pos);
self.is_tracked_node(&leaf_index)
}
/// Given a leaf position, returns the Merkle path to its corresponding peak, or None if this
/// partial MMR does not track an authentication path for the specified leaf.
///
/// Note: The leaf position is the 0-indexed number corresponding to the order in which the
/// leaves were added; it corresponds to the MMR size _prior_ to adding the element. So the
/// 1st element has position 0, the second position 1, and so on.
///
/// # Errors
/// Returns an error if the specified position is greater than or equal to the number of
/// leaves in the underlying MMR.
pub fn open(&self, pos: usize) -> Result<Option<MmrProof>, MmrError> {
let tree_bit =
leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::PositionNotFound(pos))?;
let depth = tree_bit as usize;
let mut nodes = Vec::with_capacity(depth);
let mut idx = InOrderIndex::from_leaf_pos(pos);
while let Some(node) = self.nodes.get(&idx.sibling()) {
nodes.push(*node);
idx = idx.parent();
}
// If there are nodes then the path must be complete, otherwise it is a bug
debug_assert!(nodes.is_empty() || nodes.len() == depth);
if nodes.len() != depth {
// The requested `pos` is not being tracked.
Ok(None)
} else {
Ok(Some(MmrProof {
forest: self.forest,
position: pos,
merkle_path: MerklePath::new(nodes),
}))
}
}
// ITERATORS
// --------------------------------------------------------------------------------------------
/// Returns an iterator over the nodes of all authentication paths of this [PartialMmr].
pub fn nodes(&self) -> impl Iterator<Item = (&InOrderIndex, &RpoDigest)> {
self.nodes.iter()
}
/// Returns an iterator over inner nodes of this [PartialMmr] for the specified leaves.
///
/// The order of iteration is not defined. If a leaf is not present in this partial MMR it
/// is silently ignored.
pub fn inner_nodes<'a, I: Iterator<Item = (usize, RpoDigest)> + 'a>(
&'a self,
mut leaves: I,
) -> impl Iterator<Item = InnerNodeInfo> + 'a {
let stack = if let Some((pos, leaf)) = leaves.next() {
let idx = InOrderIndex::from_leaf_pos(pos);
vec![(idx, leaf)]
} else {
Vec::new()
};
InnerNodeIterator {
nodes: &self.nodes,
leaves,
stack,
seen_nodes: BTreeSet::new(),
}
}
// STATE MUTATORS
// --------------------------------------------------------------------------------------------
/// Adds a new peak and optionally tracks it. Returns a vector of the authentication nodes
/// inserted into this [PartialMmr] as a result of this operation.
///
/// When `track` is `true` the new leaf is tracked.
pub fn add(&mut self, leaf: RpoDigest, track: bool) -> Vec<(InOrderIndex, RpoDigest)> {
self.forest += 1;
let merges = self.forest.trailing_zeros() as usize;
let mut new_nodes = Vec::with_capacity(merges);
let peak = if merges == 0 {
self.track_latest = track;
leaf
} else {
let mut track_right = track;
let mut track_left = self.track_latest;
let mut right = leaf;
let mut right_idx = forest_to_rightmost_index(self.forest);
for _ in 0..merges {
let left = self.peaks.pop().expect("Missing peak");
let left_idx = right_idx.sibling();
if track_right {
let old = self.nodes.insert(left_idx, left);
new_nodes.push((left_idx, left));
debug_assert!(
old.is_none(),
"Idx {:?} already contained an element {:?}",
left_idx,
old
);
};
if track_left {
let old = self.nodes.insert(right_idx, right);
new_nodes.push((right_idx, right));
debug_assert!(
old.is_none(),
"Idx {:?} already contained an element {:?}",
right_idx,
old
);
};
// Update state for the next iteration.
// --------------------------------------------------------------------------------
// This layer is merged, go up one layer.
right_idx = right_idx.parent();
// Merge the current layer. The result is either the right element of the next
// merge, or a new peak.
right = Rpo256::merge(&[left, right]);
// This iteration merged the left and right nodes, the new value is always used as
// the next iteration's right node. Therefore the tracking flags of this iteration
// have to be merged into the right side only.
track_right = track_right || track_left;
// On the next iteration, a peak will be merged. If any of its children are tracked,
// then we have to track the left side
track_left = self.is_tracked_node(&right_idx.sibling());
}
right
};
self.peaks.push(peak);
new_nodes
}
/// Adds the authentication path represented by [MerklePath] if it is valid.
///
/// The `leaf_pos` refers to the global position of the leaf in the MMR; these are 0-indexed
/// values assigned in a strictly monotonic fashion as elements are inserted into the MMR, and
/// correspond to the values used in the MMR structure.
///
/// The `leaf` corresponds to the value at `leaf_pos`, and `path` is the authentication path for
/// that element up to its corresponding Mmr peak. The `leaf` is only used to compute the root
/// from the authentication path to validate the data; only the authentication data is saved in
/// the structure. If the value is required it should be stored out-of-band.
pub fn track(
&mut self,
leaf_pos: usize,
leaf: RpoDigest,
path: &MerklePath,
) -> Result<(), MmrError> {
// Check that there is a tree with the same depth as the authentication path; if not, the
// path is invalid.
let tree = 1 << path.depth();
if tree & self.forest == 0 {
return Err(MmrError::UnknownPeak(path.depth()));
};
if leaf_pos + 1 == self.forest
&& path.depth() == 0
&& self.peaks.last().map_or(false, |v| *v == leaf)
{
self.track_latest = true;
return Ok(());
}
// ignore the trees smaller than the target (these elements are positioned after the current
// target and don't affect the target leaf_pos)
let target_forest = self.forest ^ (self.forest & (tree - 1));
let peak_pos = (target_forest.count_ones() - 1) as usize;
// translate from mmr leaf_pos to merkle path
let path_idx = leaf_pos - (target_forest ^ tree);
// Compute the root of the authentication path, and check it matches the current version of
// the PartialMmr.
let computed = path
.compute_root(path_idx as u64, leaf)
.map_err(MmrError::MerkleRootComputationFailed)?;
if self.peaks[peak_pos] != computed {
return Err(MmrError::PeakPathMismatch);
}
let mut idx = InOrderIndex::from_leaf_pos(leaf_pos);
for leaf in path.nodes() {
self.nodes.insert(idx.sibling(), *leaf);
idx = idx.parent();
}
Ok(())
}
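The index translation in the middle of `track` can be followed with concrete numbers. A hypothetical walk-through, not library code:

```rust
fn main() {
    // forest 0b0111 = trees of 4, 2, and 1 leaves. An authentication path
    // of depth 1 targets the 2-leaf tree: tree = 1 << 1.
    let forest = 0b0111usize;
    let tree = 0b0010usize;
    // global position 5 is the second leaf of the 2-leaf tree (it owns
    // positions 4 and 5, after the 4-leaf tree's positions 0..=3)
    let leaf_pos = 5usize;

    // drop the trees smaller than the target; they come after leaf_pos
    let target_forest = forest ^ (forest & (tree - 1));
    assert_eq!(target_forest, 0b0110);

    // subtract the leaves owned by the larger trees to get the index of
    // the leaf within its own tree
    let path_idx = leaf_pos - (target_forest ^ tree);
    assert_eq!(path_idx, 1);
}
```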
/// Removes a leaf of the [PartialMmr] and the unused nodes from the authentication path.
///
/// Note: `leaf_pos` corresponds to the position in the MMR and not on an individual tree.
pub fn untrack(&mut self, leaf_pos: usize) {
let mut idx = InOrderIndex::from_leaf_pos(leaf_pos);
self.nodes.remove(&idx.sibling());
// `idx` represents an element that can be computed from the authentication path; because
// these elements can be computed they are not saved for the authentication of the current
// target. In other words, if the idx is present it was added for the authentication of
// another element, and no more elements should be removed, otherwise it would remove that
// element's authentication data.
while !self.nodes.contains_key(&idx) {
idx = idx.parent();
self.nodes.remove(&idx.sibling());
}
}
/// Applies updates to this [PartialMmr] and returns a vector of new authentication nodes
/// inserted into the partial MMR.
pub fn apply(&mut self, delta: MmrDelta) -> Result<Vec<(InOrderIndex, RpoDigest)>, MmrError> {
if delta.forest < self.forest {
return Err(MmrError::InvalidPeaks(format!(
"forest of mmr delta {} is less than current forest {}",
delta.forest, self.forest
)));
}
let mut inserted_nodes = Vec::new();
if delta.forest == self.forest {
if !delta.data.is_empty() {
return Err(MmrError::InvalidUpdate);
}
return Ok(inserted_nodes);
}
// find the tree merges
let changes = self.forest ^ delta.forest;
let largest = 1 << changes.ilog2();
let merges = self.forest & (largest - 1);
debug_assert!(
!self.track_latest || (merges & 1) == 1,
"if there is an odd element, a merge is required"
);
// count the number of elements needed to produce `largest` from the current state
let (merge_count, new_peaks) = if merges != 0 {
let depth = largest.trailing_zeros();
let skipped = merges.trailing_zeros();
let computed = merges.count_ones() - 1;
let merge_count = depth - skipped - computed;
let new_peaks = delta.forest & (largest - 1);
(merge_count, new_peaks)
} else {
(0, changes)
};
// verify the delta size
if (delta.data.len() as u32) != merge_count + new_peaks.count_ones() {
return Err(MmrError::InvalidUpdate);
}
// keeps track of how many data elements from the update have been consumed
let mut update_count = 0;
if merges != 0 {
// starts at the smallest peak and follows the merged peaks
let mut peak_idx = forest_to_root_index(self.forest);
// match order of the update data while applying it
self.peaks.reverse();
// set to true when the data is needed for authentication paths updates
let mut track = self.track_latest;
self.track_latest = false;
let mut peak_count = 0;
let mut target = 1 << merges.trailing_zeros();
let mut new = delta.data[0];
update_count += 1;
while target < largest {
// check if either the left or right subtree has nodes saved for authentication paths.
// If so, turn tracking on to update those paths.
if target != 1 && !track {
track = self.is_tracked_node(&peak_idx);
}
// update data only contains the nodes from the right subtrees, left nodes are
// either previously known peaks or computed values
let (left, right) = if target & merges != 0 {
let peak = self.peaks[peak_count];
let sibling_idx = peak_idx.sibling();
// if the sibling peak is tracked, add this peak to the set of
// authentication nodes
if self.is_tracked_node(&sibling_idx) {
self.nodes.insert(peak_idx, new);
inserted_nodes.push((peak_idx, new));
}
peak_count += 1;
(peak, new)
} else {
let update = delta.data[update_count];
update_count += 1;
(new, update)
};
if track {
let sibling_idx = peak_idx.sibling();
if peak_idx.is_left_child() {
self.nodes.insert(sibling_idx, right);
inserted_nodes.push((sibling_idx, right));
} else {
self.nodes.insert(sibling_idx, left);
inserted_nodes.push((sibling_idx, left));
}
}
peak_idx = peak_idx.parent();
new = Rpo256::merge(&[left, right]);
target <<= 1;
}
debug_assert!(peak_count == (merges.count_ones() as usize));
// restore the peaks order
self.peaks.reverse();
// remove the merged peaks
self.peaks.truncate(self.peaks.len() - peak_count);
// add the newly computed peak, the result of the merges
self.peaks.push(new);
}
// The rest of the update data is composed of peaks. None of these elements can contain
// tracked elements because the peaks were unknown, and it is not possible to add an element
// for tracking without authenticating it to a peak.
self.peaks.extend_from_slice(&delta.data[update_count..]);
self.forest = delta.forest;
debug_assert!(self.peaks.len() == (self.forest.count_ones() as usize));
Ok(inserted_nodes)
}
// HELPER METHODS
// --------------------------------------------------------------------------------------------
/// Returns true if this [PartialMmr] tracks authentication path for the node at the specified
/// index.
fn is_tracked_node(&self, node_index: &InOrderIndex) -> bool {
if node_index.is_leaf() {
self.nodes.contains_key(&node_index.sibling())
} else {
let left_child = node_index.left_child();
let right_child = node_index.right_child();
self.nodes.contains_key(&left_child) | self.nodes.contains_key(&right_child)
}
}
}
// CONVERSIONS
// ================================================================================================
impl From<MmrPeaks> for PartialMmr {
fn from(peaks: MmrPeaks) -> Self {
Self::from_peaks(peaks)
}
}
impl From<PartialMmr> for MmrPeaks {
fn from(partial_mmr: PartialMmr) -> Self {
// Safety: the [PartialMmr] maintains the invariant that the number of true bits in the
// forest matches the number of peaks, as required by [MmrPeaks]
MmrPeaks::new(partial_mmr.forest, partial_mmr.peaks).unwrap()
}
}
impl From<&MmrPeaks> for PartialMmr {
fn from(peaks: &MmrPeaks) -> Self {
Self::from_peaks(peaks.clone())
}
}
impl From<&PartialMmr> for MmrPeaks {
fn from(partial_mmr: &PartialMmr) -> Self {
// Safety: the [PartialMmr] maintains the invariant that the number of true bits in the
// forest matches the number of peaks, as required by [MmrPeaks]
MmrPeaks::new(partial_mmr.forest, partial_mmr.peaks.clone()).unwrap()
}
}
// ITERATORS
// ================================================================================================
/// An iterator over every inner node of the [PartialMmr].
pub struct InnerNodeIterator<'a, I: Iterator<Item = (usize, RpoDigest)>> {
nodes: &'a NodeMap,
leaves: I,
stack: Vec<(InOrderIndex, RpoDigest)>,
seen_nodes: BTreeSet<InOrderIndex>,
}
impl<I: Iterator<Item = (usize, RpoDigest)>> Iterator for InnerNodeIterator<'_, I> {
type Item = InnerNodeInfo;
fn next(&mut self) -> Option<Self::Item> {
while let Some((idx, node)) = self.stack.pop() {
let parent_idx = idx.parent();
let new_node = self.seen_nodes.insert(parent_idx);
// if we haven't seen this node's parent before, and the node has a sibling, return
// the inner node defined by the parent of this node, and move up the branch
if new_node {
if let Some(sibling) = self.nodes.get(&idx.sibling()) {
let (left, right) = if parent_idx.left_child() == idx {
(node, *sibling)
} else {
(*sibling, node)
};
let parent = Rpo256::merge(&[left, right]);
let inner_node = InnerNodeInfo { value: parent, left, right };
self.stack.push((parent_idx, parent));
return Some(inner_node);
}
}
// the previous leaf has been processed, try to process the next leaf
if let Some((pos, leaf)) = self.leaves.next() {
let idx = InOrderIndex::from_leaf_pos(pos);
self.stack.push((idx, leaf));
}
}
None
}
}
impl Serializable for PartialMmr {
fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
self.forest.write_into(target);
self.peaks.write_into(target);
self.nodes.write_into(target);
target.write_bool(self.track_latest);
}
}
impl Deserializable for PartialMmr {
fn read_from<R: winter_utils::ByteReader>(
source: &mut R,
) -> Result<Self, winter_utils::DeserializationError> {
let forest = usize::read_from(source)?;
let peaks = Vec::<RpoDigest>::read_from(source)?;
let nodes = NodeMap::read_from(source)?;
let track_latest = source.read_bool()?;
Ok(Self { forest, peaks, nodes, track_latest })
}
}
// UTILS
// ================================================================================================
/// Given the description of a `forest`, returns the index of the root element of the smallest tree
/// in it.
fn forest_to_root_index(forest: usize) -> InOrderIndex {
// Count total size of all trees in the forest.
let nodes = nodes_in_forest(forest);
// Add the count for the parent nodes that separate each tree. These are allocated but
// currently empty, and correspond to the nodes that will be used once the trees are merged.
let open_trees = (forest.count_ones() - 1) as usize;
// Remove the count of the right subtree of the target tree; in an in-order tree walk the
// target tree's root index comes before its right subtree.
let right_subtree_count = ((1u32 << forest.trailing_zeros()) - 1) as usize;
let idx = nodes + open_trees - right_subtree_count;
InOrderIndex::new(idx.try_into().unwrap())
}
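The arithmetic can be checked against the unit tests below; this sketch returns a plain `usize` instead of `InOrderIndex` and inlines `nodes_in_forest`:

```rust
fn forest_to_root_index(forest: usize) -> usize {
    let nodes = forest * 2 - forest.count_ones() as usize; // nodes_in_forest
    let open_trees = (forest.count_ones() - 1) as usize;
    let right_subtree_count = (1usize << forest.trailing_zeros()) - 1;
    nodes + open_trees - right_subtree_count
}

fn main() {
    // forest 0b0011: 4 nodes + 1 separator - 0 right-subtree nodes = 5
    assert_eq!(forest_to_root_index(0b0011), 5);
    // forest 0b0100: 7 nodes + 0 separators - 3 right-subtree nodes = 4
    assert_eq!(forest_to_root_index(0b0100), 4);
}
```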
/// Given the description of a `forest`, returns the index of the rightmost element.
fn forest_to_rightmost_index(forest: usize) -> InOrderIndex {
// Count total size of all trees in the forest.
let nodes = nodes_in_forest(forest);
// Add the count for the parent nodes that separate each tree. These are allocated but
// currently empty, and correspond to the nodes that will be used once the trees are merged.
let open_trees = (forest.count_ones() - 1) as usize;
let idx = nodes + open_trees;
InOrderIndex::new(idx.try_into().unwrap())
}
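Since the total is `2*forest - ones + (ones - 1) = 2*forest - 1`, the rightmost index is always odd, i.e. always a leaf. A sketch over plain `usize`:

```rust
fn forest_to_rightmost_index(forest: usize) -> usize {
    let nodes = forest * 2 - forest.count_ones() as usize; // nodes_in_forest
    let open_trees = (forest.count_ones() - 1) as usize;
    nodes + open_trees
}

fn main() {
    // forest 0b0110: 10 nodes + 1 separator = 11, an odd (leaf) index
    assert_eq!(forest_to_rightmost_index(0b0110), 11);
    // 2*forest - ones + (ones - 1) = 2*forest - 1, so the result is odd
    for forest in 1..256usize {
        assert!(forest_to_rightmost_index(forest) % 2 == 1);
    }
}
```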
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
use alloc::{collections::BTreeSet, vec::Vec};
use winter_utils::{Deserializable, Serializable};
use super::{
forest_to_rightmost_index, forest_to_root_index, InOrderIndex, MmrPeaks, PartialMmr,
RpoDigest,
};
use crate::merkle::{int_to_node, MerkleStore, Mmr, NodeIndex};
const LEAVES: [RpoDigest; 7] = [
int_to_node(0),
int_to_node(1),
int_to_node(2),
int_to_node(3),
int_to_node(4),
int_to_node(5),
int_to_node(6),
];
#[test]
fn test_forest_to_root_index() {
fn idx(pos: usize) -> InOrderIndex {
InOrderIndex::new(pos.try_into().unwrap())
}
// When there is a single tree in the forest, the index is equivalent to the number of
// leaves in that tree, which is `2^n`.
assert_eq!(forest_to_root_index(0b0001), idx(1));
assert_eq!(forest_to_root_index(0b0010), idx(2));
assert_eq!(forest_to_root_index(0b0100), idx(4));
assert_eq!(forest_to_root_index(0b1000), idx(8));
assert_eq!(forest_to_root_index(0b0011), idx(5));
assert_eq!(forest_to_root_index(0b0101), idx(9));
assert_eq!(forest_to_root_index(0b1001), idx(17));
assert_eq!(forest_to_root_index(0b0111), idx(13));
assert_eq!(forest_to_root_index(0b1011), idx(21));
assert_eq!(forest_to_root_index(0b1111), idx(29));
assert_eq!(forest_to_root_index(0b0110), idx(10));
assert_eq!(forest_to_root_index(0b1010), idx(18));
assert_eq!(forest_to_root_index(0b1100), idx(20));
assert_eq!(forest_to_root_index(0b1110), idx(26));
}
#[test]
fn test_forest_to_rightmost_index() {
fn idx(pos: usize) -> InOrderIndex {
InOrderIndex::new(pos.try_into().unwrap())
}
for forest in 1..256 {
assert!(forest_to_rightmost_index(forest).inner() % 2 == 1, "Leaves are always odd");
}
assert_eq!(forest_to_rightmost_index(0b0001), idx(1));
assert_eq!(forest_to_rightmost_index(0b0010), idx(3));
assert_eq!(forest_to_rightmost_index(0b0011), idx(5));
assert_eq!(forest_to_rightmost_index(0b0100), idx(7));
assert_eq!(forest_to_rightmost_index(0b0101), idx(9));
assert_eq!(forest_to_rightmost_index(0b0110), idx(11));
assert_eq!(forest_to_rightmost_index(0b0111), idx(13));
assert_eq!(forest_to_rightmost_index(0b1000), idx(15));
assert_eq!(forest_to_rightmost_index(0b1001), idx(17));
assert_eq!(forest_to_rightmost_index(0b1010), idx(19));
assert_eq!(forest_to_rightmost_index(0b1011), idx(21));
assert_eq!(forest_to_rightmost_index(0b1100), idx(23));
assert_eq!(forest_to_rightmost_index(0b1101), idx(25));
assert_eq!(forest_to_rightmost_index(0b1110), idx(27));
assert_eq!(forest_to_rightmost_index(0b1111), idx(29));
}
#[test]
fn test_partial_mmr_apply_delta() {
// build an MMR with 10 nodes (2 peaks) and a partial MMR based on it
let mut mmr = Mmr::default();
(0..10).for_each(|i| mmr.add(int_to_node(i)));
let mut partial_mmr: PartialMmr = mmr.peaks().into();
// add authentication path for position 1 and 8
{
let node = mmr.get(1).unwrap();
let proof = mmr.open(1).unwrap();
partial_mmr.track(1, node, &proof.merkle_path).unwrap();
}
{
let node = mmr.get(8).unwrap();
let proof = mmr.open(8).unwrap();
partial_mmr.track(8, node, &proof.merkle_path).unwrap();
}
// add 2 more nodes into the MMR and validate apply_delta()
(10..12).for_each(|i| mmr.add(int_to_node(i)));
validate_apply_delta(&mmr, &mut partial_mmr);
// add 1 more node to the MMR, validate apply_delta() and start tracking the node
mmr.add(int_to_node(12));
validate_apply_delta(&mmr, &mut partial_mmr);
{
let node = mmr.get(12).unwrap();
let proof = mmr.open(12).unwrap();
partial_mmr.track(12, node, &proof.merkle_path).unwrap();
assert!(partial_mmr.track_latest);
}
// by this point we are tracking authentication paths for positions: 1, 8, and 12
// add 3 more nodes to the MMR (collapses to 1 peak) and validate apply_delta()
(13..16).for_each(|i| mmr.add(int_to_node(i)));
validate_apply_delta(&mmr, &mut partial_mmr);
}
fn validate_apply_delta(mmr: &Mmr, partial: &mut PartialMmr) {
let tracked_leaves = partial
.nodes
.iter()
.filter_map(|(index, _)| if index.is_leaf() { Some(index.sibling()) } else { None })
.collect::<Vec<_>>();
let nodes_before = partial.nodes.clone();
// compute and apply delta
let delta = mmr.get_delta(partial.forest(), mmr.forest()).unwrap();
let nodes_delta = partial.apply(delta).unwrap();
// new peaks were computed correctly
assert_eq!(mmr.peaks(), partial.peaks());
let mut expected_nodes = nodes_before;
for (key, value) in nodes_delta {
// nodes should not be duplicated
assert!(expected_nodes.insert(key, value).is_none());
}
// new nodes should be a combination of original nodes and delta
assert_eq!(expected_nodes, partial.nodes);
// make sure tracked leaves open to the same proofs as in the underlying MMR
for index in tracked_leaves {
let index_value: u64 = index.into();
let pos = index_value / 2;
let proof1 = partial.open(pos as usize).unwrap().unwrap();
let proof2 = mmr.open(pos as usize).unwrap();
assert_eq!(proof1, proof2);
}
}
#[test]
fn test_partial_mmr_inner_nodes_iterator() {
// build the MMR
let mmr: Mmr = LEAVES.into();
let first_peak = mmr.peaks().peaks()[0];
// -- test single tree ----------------------------
// get path and node for position 1
let node1 = mmr.get(1).unwrap();
let proof1 = mmr.open(1).unwrap();
// create partial MMR and add authentication path to node at position 1
let mut partial_mmr: PartialMmr = mmr.peaks().into();
partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();
// empty iterator should have no nodes
assert_eq!(partial_mmr.inner_nodes([].iter().cloned()).next(), None);
// build Merkle store from authentication paths in partial MMR
let mut store: MerkleStore = MerkleStore::new();
store.extend(partial_mmr.inner_nodes([(1, node1)].iter().cloned()));
let index1 = NodeIndex::new(2, 1).unwrap();
let path1 = store.get_path(first_peak, index1).unwrap().path;
assert_eq!(path1, proof1.merkle_path);
// -- test no duplicates --------------------------
// build the partial MMR
let mut partial_mmr: PartialMmr = mmr.peaks().into();
let node0 = mmr.get(0).unwrap();
let proof0 = mmr.open(0).unwrap();
let node2 = mmr.get(2).unwrap();
let proof2 = mmr.open(2).unwrap();
partial_mmr.track(0, node0, &proof0.merkle_path).unwrap();
partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();
partial_mmr.track(2, node2, &proof2.merkle_path).unwrap();
// make sure there are no duplicates
let leaves = [(0, node0), (1, node1), (2, node2)];
let mut nodes = BTreeSet::new();
for node in partial_mmr.inner_nodes(leaves.iter().cloned()) {
assert!(nodes.insert(node.value));
}
// and also that the store is still built correctly
store.extend(partial_mmr.inner_nodes(leaves.iter().cloned()));
let index0 = NodeIndex::new(2, 0).unwrap();
let index1 = NodeIndex::new(2, 1).unwrap();
let index2 = NodeIndex::new(2, 2).unwrap();
let path0 = store.get_path(first_peak, index0).unwrap().path;
let path1 = store.get_path(first_peak, index1).unwrap().path;
let path2 = store.get_path(first_peak, index2).unwrap().path;
assert_eq!(path0, proof0.merkle_path);
assert_eq!(path1, proof1.merkle_path);
assert_eq!(path2, proof2.merkle_path);
// -- test multiple trees -------------------------
// build the partial MMR
let mut partial_mmr: PartialMmr = mmr.peaks().into();
let node5 = mmr.get(5).unwrap();
let proof5 = mmr.open(5).unwrap();
partial_mmr.track(1, node1, &proof1.merkle_path).unwrap();
partial_mmr.track(5, node5, &proof5.merkle_path).unwrap();
// build Merkle store from authentication paths in partial MMR
let mut store: MerkleStore = MerkleStore::new();
store.extend(partial_mmr.inner_nodes([(1, node1), (5, node5)].iter().cloned()));
let index1 = NodeIndex::new(2, 1).unwrap();
let index5 = NodeIndex::new(1, 1).unwrap();
let second_peak = mmr.peaks().peaks()[1];
let path1 = store.get_path(first_peak, index1).unwrap().path;
let path5 = store.get_path(second_peak, index5).unwrap().path;
assert_eq!(path1, proof1.merkle_path);
assert_eq!(path5, proof5.merkle_path);
}
#[test]
fn test_partial_mmr_add_without_track() {
let mut mmr = Mmr::default();
let empty_peaks = MmrPeaks::new(0, vec![]).unwrap();
let mut partial_mmr = PartialMmr::from_peaks(empty_peaks);
for el in (0..256).map(int_to_node) {
mmr.add(el);
partial_mmr.add(el, false);
assert_eq!(mmr.peaks(), partial_mmr.peaks());
assert_eq!(mmr.forest(), partial_mmr.forest());
}
}
#[test]
fn test_partial_mmr_add_with_track() {
let mut mmr = Mmr::default();
let empty_peaks = MmrPeaks::new(0, vec![]).unwrap();
let mut partial_mmr = PartialMmr::from_peaks(empty_peaks);
for i in 0..256 {
let el = int_to_node(i);
mmr.add(el);
partial_mmr.add(el, true);
assert_eq!(mmr.peaks(), partial_mmr.peaks());
assert_eq!(mmr.forest(), partial_mmr.forest());
for pos in 0..i {
let mmr_proof = mmr.open(pos as usize).unwrap();
let partialmmr_proof = partial_mmr.open(pos as usize).unwrap().unwrap();
assert_eq!(mmr_proof, partialmmr_proof);
}
}
}
#[test]
fn test_partial_mmr_add_existing_track() {
let mut mmr = Mmr::from((0..7).map(int_to_node));
// derive a partial Mmr from it which tracks authentication path to leaf 5
let mut partial_mmr = PartialMmr::from_peaks(mmr.peaks());
let path_to_5 = mmr.open(5).unwrap().merkle_path;
let leaf_at_5 = mmr.get(5).unwrap();
partial_mmr.track(5, leaf_at_5, &path_to_5).unwrap();
// add a new leaf to both Mmr and partial Mmr
let leaf_at_7 = int_to_node(7);
mmr.add(leaf_at_7);
partial_mmr.add(leaf_at_7, false);
// the openings should be the same
assert_eq!(mmr.open(5).unwrap(), partial_mmr.open(5).unwrap().unwrap());
}
#[test]
fn test_partial_mmr_serialization() {
let mmr = Mmr::from((0..7).map(int_to_node));
let partial_mmr = PartialMmr::from_peaks(mmr.peaks());
let bytes = partial_mmr.to_bytes();
let decoded = PartialMmr::read_from_bytes(&bytes).unwrap();
assert_eq!(partial_mmr, decoded);
}
}

src/merkle/mmr/peaks.rs

@@ -0,0 +1,162 @@
use alloc::vec::Vec;
use super::{super::ZERO, Felt, MmrError, MmrProof, Rpo256, RpoDigest, Word};
// MMR PEAKS
// ================================================================================================
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct MmrPeaks {
/// The number of leaves is used to differentiate MMRs that have the same number of peaks. This
/// happens because the number of peaks goes up-and-down as the structure is used causing
/// existing trees to be merged and new ones to be created. As an example, every time the MMR
/// has a power-of-two number of leaves there is a single peak.
///
/// Every tree in the MMR forest has a distinct power-of-two size; this means only the
/// rightmost tree can have an odd number of elements (i.e. `1`). It also means that the
/// bits in `num_leaves` conveniently encode the size of each individual tree.
///
/// Examples:
///
/// - With 5 leaves, the binary representation is `0b101`. The number of set bits equals the
/// number of peaks, in this case 2. The 0-indexed position of each set bit determines the
/// number of leaves in a tree, so the rightmost tree has `2**0` leaves and the leftmost
/// has `2**2`.
/// - With 12 leaves, the binary representation is `0b1100`. This case also has 2 peaks; the
/// leftmost tree has `2**3 = 8` leaves, and the rightmost has `2**2 = 4` leaves.
num_leaves: usize,
/// All the peaks of every tree in the MMR forest. The peaks are always ordered by number of
/// leaves, starting from the peak with most children, to the one with least.
///
/// Invariant: The length of `peaks` must be equal to the number of true bits in `num_leaves`.
peaks: Vec<RpoDigest>,
}
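As a quick illustration of the bit encoding documented above, the tree sizes (and hence the peak count) can be recovered directly from `num_leaves`. A minimal sketch; `tree_sizes` is a hypothetical helper, not part of this crate:

```rust
// Each set bit of `num_leaves` is one tree/peak; the bit's position gives
// that tree's leaf count as a power of two. Largest tree first, matching
// the ordering of `peaks`.
fn tree_sizes(num_leaves: usize) -> Vec<usize> {
    (0..usize::BITS)
        .rev()
        .filter(|&bit| num_leaves & (1usize << bit) != 0)
        .map(|bit| 1usize << bit)
        .collect()
}

fn main() {
    assert_eq!(tree_sizes(5), vec![4, 1]);  // 0b101: peaks of 2^2 and 2^0 leaves
    assert_eq!(tree_sizes(12), vec![8, 4]); // 0b1100: peaks of 2^3 and 2^2 leaves
    assert_eq!(tree_sizes(8), vec![8]);     // power of two -> single peak
}
```

The invariant checked by `MmrPeaks::new` follows directly: the number of peaks must equal `num_leaves.count_ones()`.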
impl MmrPeaks {
// CONSTRUCTOR
// --------------------------------------------------------------------------------------------
/// Returns new [MmrPeaks] instantiated from the provided vector of peaks and the number of
/// leaves in the underlying MMR.
///
/// # Errors
/// Returns an error if the number of leaves and the number of peaks are inconsistent.
pub fn new(num_leaves: usize, peaks: Vec<RpoDigest>) -> Result<Self, MmrError> {
if num_leaves.count_ones() as usize != peaks.len() {
return Err(MmrError::InvalidPeaks(format!(
"number of one bits in leaves is {} which does not equal peak length {}",
num_leaves.count_ones(),
peaks.len()
)));
}
Ok(Self { num_leaves, peaks })
}
// ACCESSORS
// --------------------------------------------------------------------------------------------
/// Returns a count of leaves in the underlying MMR.
pub fn num_leaves(&self) -> usize {
self.num_leaves
}
/// Returns the number of peaks of the underlying MMR.
pub fn num_peaks(&self) -> usize {
self.peaks.len()
}
/// Returns the list of peaks of the underlying MMR.
pub fn peaks(&self) -> &[RpoDigest] {
&self.peaks
}
/// Returns the peak by the provided index.
///
/// # Errors
/// Returns an error if the provided peak index is greater than or equal to the current
/// number of peaks in the Mmr.
pub fn get_peak(&self, peak_idx: usize) -> Result<&RpoDigest, MmrError> {
self.peaks
.get(peak_idx)
.ok_or(MmrError::PeakOutOfBounds { peak_idx, peaks_len: self.peaks.len() })
}
/// Converts this [MmrPeaks] into its components: number of leaves and a vector of peaks of
/// the underlying MMR.
pub fn into_parts(self) -> (usize, Vec<RpoDigest>) {
(self.num_leaves, self.peaks)
}
/// Hashes the peaks.
///
/// The procedure will:
/// - Flatten and pad the peaks to a vector of Felts.
/// - Hash the vector of Felts.
pub fn hash_peaks(&self) -> RpoDigest {
Rpo256::hash_elements(&self.flatten_and_pad_peaks())
}
/// Verifies the Merkle opening proof.
///
/// # Errors
/// Returns an error if:
/// - provided opening proof is invalid.
/// - Mmr root value computed using the provided leaf value differs from the actual one.
pub fn verify(&self, value: RpoDigest, opening: MmrProof) -> Result<(), MmrError> {
let root = self.get_peak(opening.peak_index())?;
opening
.merkle_path
.verify(opening.relative_pos() as u64, value, root)
.map_err(MmrError::InvalidMerklePath)
}
/// Flattens and pads the peaks to make hashing inside of the Miden VM easier.
///
/// The procedure will:
/// - Flatten the vector of Words into a vector of Felts.
/// - Pad the peaks with ZERO to an even number of words; this removes the need to handle RPO
/// padding.
/// - Pad the peaks to a minimum length of 16 words, which reduces the constant cost of hashing.
pub fn flatten_and_pad_peaks(&self) -> Vec<Felt> {
let num_peaks = self.peaks.len();
// To achieve the padding rules above we calculate the length of the final vector.
// This is calculated as the number of field elements. Each peak is 4 field elements.
// The length is calculated as follows:
// - If there are fewer than 16 peaks, the data is padded to 16 peaks and as such requires 64
// field elements.
// - If there are 16 or more peaks and the number of peaks is odd, the data is padded to
// an even number of peaks and as such requires `(num_peaks + 1) * 4` field elements.
// - If there are 16 or more peaks and the number of peaks is even, the data is not padded
// and as such requires `num_peaks * 4` field elements.
let len = if num_peaks < 16 {
64
} else if num_peaks % 2 == 1 {
(num_peaks + 1) * 4
} else {
num_peaks * 4
};
let mut elements = Vec::with_capacity(len);
elements.extend_from_slice(
&self
.peaks
.as_slice()
.iter()
.map(|digest| digest.into())
.collect::<Vec<Word>>()
.concat(),
);
elements.resize(len, ZERO);
elements
}
}
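The length computation in `flatten_and_pad_peaks` can be summarized by a small standalone function (illustrative only, mirroring the three cases spelled out in the comment above):

```rust
// Each peak is 4 field elements; pad to at least 16 peaks (64 elements)
// and to an even peak count.
fn padded_len(num_peaks: usize) -> usize {
    if num_peaks < 16 {
        64
    } else if num_peaks % 2 == 1 {
        (num_peaks + 1) * 4
    } else {
        num_peaks * 4
    }
}

fn main() {
    assert_eq!(padded_len(3), 64);  // fewer than 16 peaks -> padded to 16 * 4
    assert_eq!(padded_len(16), 64); // exactly 16, even -> no extra padding
    assert_eq!(padded_len(17), 72); // odd -> padded to 18 peaks
    assert_eq!(padded_len(20), 80); // even -> unchanged
}
```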
impl From<MmrPeaks> for Vec<RpoDigest> {
fn from(peaks: MmrPeaks) -> Self {
peaks.peaks
}
}


@@ -1,6 +1,9 @@
/// The representation of a single Merkle path.
use super::super::MerklePath;
use super::full::{high_bitmask, leaf_to_corresponding_tree};
use super::{full::high_bitmask, leaf_to_corresponding_tree};
// MMR PROOF
// ================================================================================================
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
@@ -26,9 +29,78 @@ impl MmrProof {
self.position - forest_before
}
/// Returns index of the MMR peak against which the Merkle path in this proof can be verified.
pub fn peak_index(&self) -> usize {
let root = leaf_to_corresponding_tree(self.position, self.forest)
.expect("position must be part of the forest");
(self.forest.count_ones() - root - 1) as usize
let smaller_peak_mask = 2_usize.pow(root) as usize - 1;
let num_smaller_peaks = (self.forest & smaller_peak_mask).count_ones();
(self.forest.count_ones() - num_smaller_peaks - 1) as usize
}
}
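The corrected `peak_index` computation can be sketched standalone. Here `leaf_to_tree` is a simplified stand-in for `leaf_to_corresponding_tree` (an assumption for illustration, not the crate's implementation):

```rust
// Returns the bit index of the tree that `pos` falls into, walking trees
// from largest to smallest.
fn leaf_to_tree(forest: usize, mut pos: usize) -> u32 {
    for bit in (0..usize::BITS).rev() {
        let size = 1usize << bit;
        if forest & size != 0 {
            if pos < size {
                return bit;
            }
            pos -= size;
        }
    }
    panic!("position must be part of the forest");
}

fn peak_index(forest: usize, position: usize) -> usize {
    let root = leaf_to_tree(forest, position);
    // count only the peaks of trees *smaller* than the leaf's tree;
    // bits missing from the forest must not shift the index
    let smaller_peaks = (forest & ((1usize << root) - 1)).count_ones();
    (forest.count_ones() - smaller_peaks - 1) as usize
}

fn main() {
    // forest 0b0111: trees of 4, 2 and 1 leaves (cf. test_peak_index)
    assert_eq!(peak_index(0b0111, 0), 0);
    assert_eq!(peak_index(0b0111, 4), 1);
    assert_eq!(peak_index(0b0111, 6), 2);
    // forest 0b1010 has a gap at bit 2: masking matters here, since
    // `count_ones - root - 1` alone would wrongly return 0
    assert_eq!(peak_index(0b1010, 8), 1);
}
```

The last assertion shows why the mask was introduced: with an unset bit between peaks, subtracting the tree's bit index directly over-counts the peaks to its right.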
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
use super::{MerklePath, MmrProof};
#[test]
fn test_peak_index() {
// --- single peak forest ---------------------------------------------
let forest = 4;
// all 4 leaves belong to the single peak 0
for position in 0..4 {
let proof = make_dummy_proof(forest, position);
assert_eq!(proof.peak_index(), 0);
}
// --- forest with non-consecutive peaks ------------------------------
let forest = 11;
// the first 8 leaves belong to peak 0
for position in 0..8 {
let proof = make_dummy_proof(forest, position);
assert_eq!(proof.peak_index(), 0);
}
// the next 2 leaves belong to peak 1
for position in 8..10 {
let proof = make_dummy_proof(forest, position);
assert_eq!(proof.peak_index(), 1);
}
// the last leaf is the peak 2
let proof = make_dummy_proof(forest, 10);
assert_eq!(proof.peak_index(), 2);
// --- forest with consecutive peaks ----------------------------------
let forest = 7;
// the first 4 leaves belong to peak 0
for position in 0..4 {
let proof = make_dummy_proof(forest, position);
assert_eq!(proof.peak_index(), 0);
}
// the next 2 leaves belong to peak 1
for position in 4..6 {
let proof = make_dummy_proof(forest, position);
assert_eq!(proof.peak_index(), 1);
}
// the last leaf is the peak 2
let proof = make_dummy_proof(forest, 6);
assert_eq!(proof.peak_index(), 2);
}
fn make_dummy_proof(forest: usize, position: usize) -> MmrProof {
MmrProof {
forest,
position,
merkle_path: MerklePath::default(),
}
}
}


@@ -1,12 +1,13 @@
use alloc::vec::Vec;
use super::{
super::{InnerNodeInfo, Vec},
super::{InnerNodeInfo, Rpo256, RpoDigest},
bit::TrueBitPositionIterator,
full::{high_bitmask, leaf_to_corresponding_tree, nodes_in_forest},
Mmr, MmrPeaks, Rpo256,
full::high_bitmask,
leaf_to_corresponding_tree, nodes_in_forest, Mmr, MmrPeaks, PartialMmr,
};
use crate::{
hash::rpo::RpoDigest,
merkle::{int_to_node, MerklePath},
merkle::{int_to_node, InOrderIndex, MerklePath, MerkleTree, MmrProof, NodeIndex},
Felt, Word,
};
@@ -115,17 +116,18 @@ const LEAVES: [RpoDigest; 7] = [
#[test]
fn test_mmr_simple() {
let mut postorder = Vec::new();
postorder.push(LEAVES[0]);
postorder.push(LEAVES[1]);
postorder.push(Rpo256::merge(&[LEAVES[0], LEAVES[1]]));
postorder.push(LEAVES[2]);
postorder.push(LEAVES[3]);
postorder.push(Rpo256::merge(&[LEAVES[2], LEAVES[3]]));
postorder.push(Rpo256::merge(&[postorder[2], postorder[5]]));
let mut postorder = vec![
LEAVES[0],
LEAVES[1],
merge(LEAVES[0], LEAVES[1]),
LEAVES[2],
LEAVES[3],
merge(LEAVES[2], LEAVES[3]),
];
postorder.push(merge(postorder[2], postorder[5]));
postorder.push(LEAVES[4]);
postorder.push(LEAVES[5]);
postorder.push(Rpo256::merge(&[LEAVES[4], LEAVES[5]]));
postorder.push(merge(LEAVES[4], LEAVES[5]));
postorder.push(LEAVES[6]);
let mut mmr = Mmr::new();
@@ -137,70 +139,70 @@ fn test_mmr_simple() {
assert_eq!(mmr.nodes.len(), 1);
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
let acc = mmr.accumulator();
assert_eq!(acc.num_leaves, 1);
assert_eq!(acc.peaks, &[postorder[0]]);
let acc = mmr.peaks();
assert_eq!(acc.num_leaves(), 1);
assert_eq!(acc.peaks(), &[postorder[0]]);
mmr.add(LEAVES[1]);
assert_eq!(mmr.forest(), 2);
assert_eq!(mmr.nodes.len(), 3);
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
let acc = mmr.accumulator();
assert_eq!(acc.num_leaves, 2);
assert_eq!(acc.peaks, &[postorder[2]]);
let acc = mmr.peaks();
assert_eq!(acc.num_leaves(), 2);
assert_eq!(acc.peaks(), &[postorder[2]]);
mmr.add(LEAVES[2]);
assert_eq!(mmr.forest(), 3);
assert_eq!(mmr.nodes.len(), 4);
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
let acc = mmr.accumulator();
assert_eq!(acc.num_leaves, 3);
assert_eq!(acc.peaks, &[postorder[2], postorder[3]]);
let acc = mmr.peaks();
assert_eq!(acc.num_leaves(), 3);
assert_eq!(acc.peaks(), &[postorder[2], postorder[3]]);
mmr.add(LEAVES[3]);
assert_eq!(mmr.forest(), 4);
assert_eq!(mmr.nodes.len(), 7);
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
let acc = mmr.accumulator();
assert_eq!(acc.num_leaves, 4);
assert_eq!(acc.peaks, &[postorder[6]]);
let acc = mmr.peaks();
assert_eq!(acc.num_leaves(), 4);
assert_eq!(acc.peaks(), &[postorder[6]]);
mmr.add(LEAVES[4]);
assert_eq!(mmr.forest(), 5);
assert_eq!(mmr.nodes.len(), 8);
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
let acc = mmr.accumulator();
assert_eq!(acc.num_leaves, 5);
assert_eq!(acc.peaks, &[postorder[6], postorder[7]]);
let acc = mmr.peaks();
assert_eq!(acc.num_leaves(), 5);
assert_eq!(acc.peaks(), &[postorder[6], postorder[7]]);
mmr.add(LEAVES[5]);
assert_eq!(mmr.forest(), 6);
assert_eq!(mmr.nodes.len(), 10);
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
let acc = mmr.accumulator();
assert_eq!(acc.num_leaves, 6);
assert_eq!(acc.peaks, &[postorder[6], postorder[9]]);
let acc = mmr.peaks();
assert_eq!(acc.num_leaves(), 6);
assert_eq!(acc.peaks(), &[postorder[6], postorder[9]]);
mmr.add(LEAVES[6]);
assert_eq!(mmr.forest(), 7);
assert_eq!(mmr.nodes.len(), 11);
assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]);
let acc = mmr.accumulator();
assert_eq!(acc.num_leaves, 7);
assert_eq!(acc.peaks, &[postorder[6], postorder[9], postorder[10]]);
let acc = mmr.peaks();
assert_eq!(acc.num_leaves(), 7);
assert_eq!(acc.peaks(), &[postorder[6], postorder[9], postorder[10]]);
}
#[test]
fn test_mmr_open() {
let mmr: Mmr = LEAVES.into();
let h01 = Rpo256::merge(&[LEAVES[0], LEAVES[1]]);
let h23 = Rpo256::merge(&[LEAVES[2], LEAVES[3]]);
let h01 = merge(LEAVES[0], LEAVES[1]);
let h23 = merge(LEAVES[2], LEAVES[3]);
// node at pos 7 is the root
assert!(mmr.open(7).is_err(), "Element 7 is not in the tree, opening should fail");
@@ -213,10 +215,7 @@ fn test_mmr_open() {
assert_eq!(opening.merkle_path, empty);
assert_eq!(opening.forest, mmr.forest);
assert_eq!(opening.position, 6);
assert!(
mmr.accumulator().verify(LEAVES[6], opening),
"MmrProof should be valid for the current accumulator."
);
mmr.peaks().verify(LEAVES[6], opening).unwrap();
// nodes 4,5 are depth 1
let root_to_path = MerklePath::new(vec![LEAVES[4]]);
@@ -226,10 +225,7 @@ fn test_mmr_open() {
assert_eq!(opening.merkle_path, root_to_path);
assert_eq!(opening.forest, mmr.forest);
assert_eq!(opening.position, 5);
assert!(
mmr.accumulator().verify(LEAVES[5], opening),
"MmrProof should be valid for the current accumulator."
);
mmr.peaks().verify(LEAVES[5], opening).unwrap();
let root_to_path = MerklePath::new(vec![LEAVES[5]]);
let opening = mmr
@@ -238,10 +234,7 @@ fn test_mmr_open() {
assert_eq!(opening.merkle_path, root_to_path);
assert_eq!(opening.forest, mmr.forest);
assert_eq!(opening.position, 4);
assert!(
mmr.accumulator().verify(LEAVES[4], opening),
"MmrProof should be valid for the current accumulator."
);
mmr.peaks().verify(LEAVES[4], opening).unwrap();
// nodes 0,1,2,3 are depth 2
let root_to_path = MerklePath::new(vec![LEAVES[2], h01]);
@@ -251,10 +244,7 @@ fn test_mmr_open() {
assert_eq!(opening.merkle_path, root_to_path);
assert_eq!(opening.forest, mmr.forest);
assert_eq!(opening.position, 3);
assert!(
mmr.accumulator().verify(LEAVES[3], opening),
"MmrProof should be valid for the current accumulator."
);
mmr.peaks().verify(LEAVES[3], opening).unwrap();
let root_to_path = MerklePath::new(vec![LEAVES[3], h01]);
let opening = mmr
@@ -263,10 +253,7 @@ fn test_mmr_open() {
assert_eq!(opening.merkle_path, root_to_path);
assert_eq!(opening.forest, mmr.forest);
assert_eq!(opening.position, 2);
assert!(
mmr.accumulator().verify(LEAVES[2], opening),
"MmrProof should be valid for the current accumulator."
);
mmr.peaks().verify(LEAVES[2], opening).unwrap();
let root_to_path = MerklePath::new(vec![LEAVES[0], h23]);
let opening = mmr
@@ -275,10 +262,7 @@ fn test_mmr_open() {
assert_eq!(opening.merkle_path, root_to_path);
assert_eq!(opening.forest, mmr.forest);
assert_eq!(opening.position, 1);
assert!(
mmr.accumulator().verify(LEAVES[1], opening),
"MmrProof should be valid for the current accumulator."
);
mmr.peaks().verify(LEAVES[1], opening).unwrap();
let root_to_path = MerklePath::new(vec![LEAVES[1], h23]);
let opening = mmr
@@ -287,10 +271,171 @@ fn test_mmr_open() {
assert_eq!(opening.merkle_path, root_to_path);
assert_eq!(opening.forest, mmr.forest);
assert_eq!(opening.position, 0);
assert!(
mmr.accumulator().verify(LEAVES[0], opening),
"MmrProof should be valid for the current accumulator."
);
mmr.peaks().verify(LEAVES[0], opening).unwrap();
}
#[test]
fn test_mmr_open_older_version() {
let mmr: Mmr = LEAVES.into();
fn is_even(v: &usize) -> bool {
v & 1 == 0
}
// merkle path of a node is empty if there are no elements to pair with it
for pos in (0..mmr.forest()).filter(is_even) {
let forest = pos + 1;
let proof = mmr.open_at(pos, forest).unwrap();
assert_eq!(proof.forest, forest);
assert_eq!(proof.merkle_path.nodes(), []);
assert_eq!(proof.position, pos);
}
// openings match that of a merkle tree
let mtree: MerkleTree = LEAVES[..4].try_into().unwrap();
for forest in 4..=LEAVES.len() {
for pos in 0..4 {
let idx = NodeIndex::new(2, pos).unwrap();
let path = mtree.get_path(idx).unwrap();
let proof = mmr.open_at(pos as usize, forest).unwrap();
assert_eq!(path, proof.merkle_path);
}
}
let mtree: MerkleTree = LEAVES[4..6].try_into().unwrap();
for forest in 6..=LEAVES.len() {
for pos in 0..2 {
let idx = NodeIndex::new(1, pos).unwrap();
let path = mtree.get_path(idx).unwrap();
// account for the bigger tree with 4 elements
let mmr_pos = (pos + 4) as usize;
let proof = mmr.open_at(mmr_pos, forest).unwrap();
assert_eq!(path, proof.merkle_path);
}
}
}
/// Tests the openings of a simple Mmr with a single tree of 8 leaves.
#[test]
fn test_mmr_open_eight() {
let leaves = [
int_to_node(0),
int_to_node(1),
int_to_node(2),
int_to_node(3),
int_to_node(4),
int_to_node(5),
int_to_node(6),
int_to_node(7),
];
let mtree: MerkleTree = leaves.as_slice().try_into().unwrap();
let forest = leaves.len();
let mmr: Mmr = leaves.into();
let root = mtree.root();
let position = 0;
let proof = mmr.open(position).unwrap();
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
let position = 1;
let proof = mmr.open(position).unwrap();
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
let position = 2;
let proof = mmr.open(position).unwrap();
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
let position = 3;
let proof = mmr.open(position).unwrap();
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
let position = 4;
let proof = mmr.open(position).unwrap();
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
let position = 5;
let proof = mmr.open(position).unwrap();
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
let position = 6;
let proof = mmr.open(position).unwrap();
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
let position = 7;
let proof = mmr.open(position).unwrap();
let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root);
}
/// Tests the openings of an Mmr with 3 trees of 4, 2, and 1 leaves.
#[test]
fn test_mmr_open_seven() {
let mtree1: MerkleTree = LEAVES[..4].try_into().unwrap();
let mtree2: MerkleTree = LEAVES[4..6].try_into().unwrap();
let forest = LEAVES.len();
let mmr: Mmr = LEAVES.into();
let position = 0;
let proof = mmr.open(position).unwrap();
let merkle_path: MerklePath =
mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(0, LEAVES[0]).unwrap(), mtree1.root());
let position = 1;
let proof = mmr.open(position).unwrap();
let merkle_path: MerklePath =
mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(1, LEAVES[1]).unwrap(), mtree1.root());
let position = 2;
let proof = mmr.open(position).unwrap();
let merkle_path: MerklePath =
mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(2, LEAVES[2]).unwrap(), mtree1.root());
let position = 3;
let proof = mmr.open(position).unwrap();
let merkle_path: MerklePath =
mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(3, LEAVES[3]).unwrap(), mtree1.root());
let position = 4;
let proof = mmr.open(position).unwrap();
let merkle_path: MerklePath = mtree2.get_path(NodeIndex::new(1, 0u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(0, LEAVES[4]).unwrap(), mtree2.root());
let position = 5;
let proof = mmr.open(position).unwrap();
let merkle_path: MerklePath = mtree2.get_path(NodeIndex::new(1, 1u64).unwrap()).unwrap();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(1, LEAVES[5]).unwrap(), mtree2.root());
let position = 6;
let proof = mmr.open(position).unwrap();
let merkle_path: MerklePath = [].as_ref().into();
assert_eq!(proof, MmrProof { forest, position, merkle_path });
assert_eq!(proof.merkle_path.compute_root(0, LEAVES[6]).unwrap(), LEAVES[6]);
}
#[test]
@@ -311,15 +456,16 @@ fn test_mmr_invariants() {
let mut mmr = Mmr::new();
for v in 1..=1028 {
mmr.add(int_to_node(v));
let accumulator = mmr.accumulator();
let accumulator = mmr.peaks();
assert_eq!(v as usize, mmr.forest(), "MMR leaf count must increase by one on every add");
assert_eq!(
v as usize, accumulator.num_leaves,
v as usize,
accumulator.num_leaves(),
"MMR and its accumulator must match leaves count"
);
assert_eq!(
accumulator.num_leaves.count_ones() as usize,
accumulator.peaks.len(),
accumulator.num_leaves().count_ones() as usize,
accumulator.peaks().len(),
"bits on leaves must match the number of peaks"
);
@@ -391,10 +537,50 @@ fn test_mmr_inner_nodes() {
assert_eq!(postorder, nodes);
}
#[test]
fn test_mmr_peaks() {
let mmr: Mmr = LEAVES.into();
let forest = 0b0001;
let acc = mmr.peaks_at(forest).unwrap();
assert_eq!(acc.num_leaves(), forest);
assert_eq!(acc.peaks(), &[mmr.nodes[0]]);
let forest = 0b0010;
let acc = mmr.peaks_at(forest).unwrap();
assert_eq!(acc.num_leaves(), forest);
assert_eq!(acc.peaks(), &[mmr.nodes[2]]);
let forest = 0b0011;
let acc = mmr.peaks_at(forest).unwrap();
assert_eq!(acc.num_leaves(), forest);
assert_eq!(acc.peaks(), &[mmr.nodes[2], mmr.nodes[3]]);
let forest = 0b0100;
let acc = mmr.peaks_at(forest).unwrap();
assert_eq!(acc.num_leaves(), forest);
assert_eq!(acc.peaks(), &[mmr.nodes[6]]);
let forest = 0b0101;
let acc = mmr.peaks_at(forest).unwrap();
assert_eq!(acc.num_leaves(), forest);
assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[7]]);
let forest = 0b0110;
let acc = mmr.peaks_at(forest).unwrap();
assert_eq!(acc.num_leaves(), forest);
assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[9]]);
let forest = 0b0111;
let acc = mmr.peaks_at(forest).unwrap();
assert_eq!(acc.num_leaves(), forest);
assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[9], mmr.nodes[10]]);
}
#[test]
fn test_mmr_hash_peaks() {
let mmr: Mmr = LEAVES.into();
let peaks = mmr.peaks();
let first_peak = Rpo256::merge(&[
Rpo256::merge(&[LEAVES[0], LEAVES[1]]),
@@ -406,10 +592,7 @@ fn test_mmr_hash_peaks() {
// minimum length is 16
let mut expected_peaks = [first_peak, second_peak, third_peak].to_vec();
expected_peaks.resize(16, RpoDigest::default());
assert_eq!(peaks.hash_peaks(), Rpo256::hash_elements(&digests_to_elements(&expected_peaks)));
}
#[test]
@@ -418,17 +601,16 @@ fn test_mmr_peaks_hash_less_than_16() {
for i in 0..16 {
peaks.push(int_to_node(i));
let num_leaves = (1 << peaks.len()) - 1;
let accumulator = MmrPeaks::new(num_leaves, peaks.clone()).unwrap();
// minimum length is 16
let mut expected_peaks = peaks.clone();
expected_peaks.resize(16, RpoDigest::default());
assert_eq!(
accumulator.hash_peaks(),
Rpo256::hash_elements(&digests_to_elements(&expected_peaks))
);
}
}
@@ -437,24 +619,239 @@ fn test_mmr_peaks_hash_less_than_16() {
fn test_mmr_peaks_hash_odd() {
let peaks: Vec<_> = (0..=17).map(int_to_node).collect();
let num_leaves = (1 << peaks.len()) - 1;
let accumulator = MmrPeaks::new(num_leaves, peaks.clone()).unwrap();
// odd length bigger than 16 is padded to the next even number
let mut expected_peaks = peaks;
expected_peaks.resize(18, RpoDigest::default());
assert_eq!(
accumulator.hash_peaks(),
Rpo256::hash_elements(&digests_to_elements(&expected_peaks))
);
}
#[test]
fn test_mmr_delta() {
let mmr: Mmr = LEAVES.into();
let acc = mmr.peaks();
// original_forest can't have more elements
assert!(
mmr.get_delta(LEAVES.len() + 1, mmr.forest()).is_err(),
"Can not provide updates for a newer Mmr"
);
// if the number of elements is the same there is no change
assert!(
mmr.get_delta(LEAVES.len(), mmr.forest()).unwrap().data.is_empty(),
"There are no updates for the same Mmr version"
);
// missing the last element added, which is itself a tree peak
assert_eq!(mmr.get_delta(6, mmr.forest()).unwrap().data, vec![acc.peaks()[2]], "one peak");
// missing the sibling to complete the tree of depth 2, and the last element
assert_eq!(
mmr.get_delta(5, mmr.forest()).unwrap().data,
vec![LEAVES[5], acc.peaks()[2]],
"one sibling, one peak"
);
// missing the whole last two trees, only send the peaks
assert_eq!(
mmr.get_delta(4, mmr.forest()).unwrap().data,
vec![acc.peaks()[1], acc.peaks()[2]],
"two peaks"
);
// missing the sibling to complete the first tree, and the two last trees
assert_eq!(
mmr.get_delta(3, mmr.forest()).unwrap().data,
vec![LEAVES[3], acc.peaks()[1], acc.peaks()[2]],
"one sibling, two peaks"
);
// missing half of the first tree, only send the computed element (not the leaves), and the new
// peaks
assert_eq!(
mmr.get_delta(2, mmr.forest()).unwrap().data,
vec![mmr.nodes[5], acc.peaks()[1], acc.peaks()[2]],
"one sibling, two peaks"
);
assert_eq!(
mmr.get_delta(1, mmr.forest()).unwrap().data,
vec![LEAVES[1], mmr.nodes[5], acc.peaks()[1], acc.peaks()[2]],
"one sibling, two peaks"
);
assert_eq!(&mmr.get_delta(0, mmr.forest()).unwrap().data, acc.peaks(), "all peaks");
}
#[test]
fn test_mmr_delta_old_forest() {
let mmr: Mmr = LEAVES.into();
// from_forest must be smaller-or-equal to to_forest
for version in 1..=mmr.forest() {
assert!(mmr.get_delta(version + 1, version).is_err());
}
// when from_forest and to_forest are equal, there are no updates
for version in 1..=mmr.forest() {
let delta = mmr.get_delta(version, version).unwrap();
assert!(delta.data.is_empty());
assert_eq!(delta.forest, version);
}
// test update which merges the odd peak to the right
for count in 0..(mmr.forest() / 2) {
// *2 because every iteration tests a pair
// +1 because the Mmr is 1-indexed
let from_forest = (count * 2) + 1;
let to_forest = (count * 2) + 2;
let delta = mmr.get_delta(from_forest, to_forest).unwrap();
// *2 because every iteration tests a pair
// +1 because sibling is the odd element
let sibling = (count * 2) + 1;
assert_eq!(delta.data, [LEAVES[sibling]]);
assert_eq!(delta.forest, to_forest);
}
let version = 4;
let delta = mmr.get_delta(1, version).unwrap();
assert_eq!(delta.data, [mmr.nodes[1], mmr.nodes[5]]);
assert_eq!(delta.forest, version);
let version = 5;
let delta = mmr.get_delta(1, version).unwrap();
assert_eq!(delta.data, [mmr.nodes[1], mmr.nodes[5], mmr.nodes[7]]);
assert_eq!(delta.forest, version);
}
#[test]
fn test_partial_mmr_simple() {
let mmr: Mmr = LEAVES.into();
let peaks = mmr.peaks();
let mut partial: PartialMmr = peaks.clone().into();
// check initial state of the partial mmr
assert_eq!(partial.peaks(), peaks);
assert_eq!(partial.forest(), peaks.num_leaves());
assert_eq!(partial.forest(), LEAVES.len());
assert_eq!(partial.peaks().num_peaks(), 3);
assert_eq!(partial.nodes.len(), 0);
// check state after adding tracking one element
let proof1 = mmr.open(0).unwrap();
let el1 = mmr.get(proof1.position).unwrap();
partial.track(proof1.position, el1, &proof1.merkle_path).unwrap();
// check the number of nodes increased by the number of nodes in the proof
assert_eq!(partial.nodes.len(), proof1.merkle_path.len());
// check the values match
let idx = InOrderIndex::from_leaf_pos(proof1.position);
assert_eq!(partial.nodes[&idx.sibling()], proof1.merkle_path[0]);
let idx = idx.parent();
assert_eq!(partial.nodes[&idx.sibling()], proof1.merkle_path[1]);
let proof2 = mmr.open(1).unwrap();
let el2 = mmr.get(proof2.position).unwrap();
partial.track(proof2.position, el2, &proof2.merkle_path).unwrap();
// check the number of nodes increased by a single element (the one that is not shared)
assert_eq!(partial.nodes.len(), 3);
// check the values match
let idx = InOrderIndex::from_leaf_pos(proof2.position);
assert_eq!(partial.nodes[&idx.sibling()], proof2.merkle_path[0]);
let idx = idx.parent();
assert_eq!(partial.nodes[&idx.sibling()], proof2.merkle_path[1]);
}
#[test]
fn test_partial_mmr_update_single() {
let mut full = Mmr::new();
let zero = int_to_node(0);
full.add(zero);
let mut partial: PartialMmr = full.peaks().into();
let proof = full.open(0).unwrap();
partial.track(proof.position, zero, &proof.merkle_path).unwrap();
for i in 1..100 {
let node = int_to_node(i);
full.add(node);
let delta = full.get_delta(partial.forest(), full.forest()).unwrap();
partial.apply(delta).unwrap();
assert_eq!(partial.forest(), full.forest());
assert_eq!(partial.peaks(), full.peaks());
let proof1 = full.open(i as usize).unwrap();
partial.track(proof1.position, node, &proof1.merkle_path).unwrap();
let proof2 = partial.open(proof1.position).unwrap().unwrap();
assert_eq!(proof1.merkle_path, proof2.merkle_path);
}
}
#[test]
fn test_mmr_add_invalid_odd_leaf() {
let mmr: Mmr = LEAVES.into();
let acc = mmr.peaks();
let mut partial: PartialMmr = acc.clone().into();
let empty = MerklePath::new(Vec::new());
// None of the other leaves should work
for node in LEAVES.iter().cloned().rev().skip(1) {
let result = partial.track(LEAVES.len() - 1, node, &empty);
assert!(result.is_err());
}
let result = partial.track(LEAVES.len() - 1, LEAVES[6], &empty);
assert!(result.is_ok());
}
/// Tests that a proof whose peak count exceeds the peak count of the MMR returns an error.
///
/// Here we manipulate the proof to return a peak index of 1 while the MMR only has 1 peak (with
/// index 0).
#[test]
#[should_panic]
fn test_mmr_proof_num_peaks_exceeds_current_num_peaks() {
let mmr: Mmr = LEAVES[0..4].iter().cloned().into();
let mut proof = mmr.open(3).unwrap();
proof.forest = 5;
proof.position = 4;
mmr.peaks().verify(LEAVES[3], proof).unwrap();
}
/// Tests that a proof whose peak count exceeds the peak count of the MMR returns an error.
///
/// We create an MmrProof for a leaf whose peak index to verify against is 1.
/// Then we add another leaf which results in an Mmr with just one peak due to trees
/// being merged. If we try to use the old proof against the new Mmr, we should get an error.
#[test]
#[should_panic]
fn test_mmr_old_proof_num_peaks_exceeds_current_num_peaks() {
let leaves_len = 3;
let mut mmr = Mmr::from(LEAVES[0..leaves_len].iter().cloned());
let leaf_idx = leaves_len - 1;
let proof = mmr.open(leaf_idx).unwrap();
assert!(mmr.peaks().verify(LEAVES[leaf_idx], proof.clone()).is_ok());
mmr.add(LEAVES[leaves_len]);
mmr.peaks().verify(LEAVES[leaf_idx], proof).unwrap();
}
mod property_tests {
use proptest::prelude::*;
use super::leaf_to_corresponding_tree;
proptest! {
#[test]
fn test_last_position_is_always_contained_in_the_last_tree(leaves in any::<usize>().prop_filter("cant have an empty tree", |v| *v != 0)) {
@@ -471,10 +868,10 @@ mod property_tests {
proptest! {
#[test]
fn test_contained_tree_is_always_power_of_two((leaves, pos) in any::<usize>().prop_flat_map(|v| (Just(v), 0..v))) {
let tree_bit = leaf_to_corresponding_tree(pos, leaves).expect("pos is smaller than leaves, there should always be a corresponding tree");
let mask = 1usize << tree_bit;
assert!(tree_bit < usize::BITS, "the result must be a bit in usize");
assert!(mask & leaves != 0, "the result should be a tree in leaves");
}
}
@@ -486,3 +883,8 @@ mod property_tests {
fn digests_to_elements(digests: &[RpoDigest]) -> Vec<Felt> {
digests.iter().flat_map(Word::from).collect()
}
// short hand for the rpo hash, used to make test code more concise and easy to read
fn merge(l: RpoDigest, r: RpoDigest) -> RpoDigest {
Rpo256::merge(&[l, r])
}


@@ -2,8 +2,7 @@
use super::{
hash::rpo::{Rpo256, RpoDigest},
utils::collections::{vec, BTreeMap, BTreeSet, KvMap, RecordingMap, TryApplyDiff, Vec},
Felt, Word, EMPTY_WORD, ZERO,
};
// REEXPORTS
@@ -12,9 +11,6 @@ use super::{
mod empty_roots;
pub use empty_roots::EmptySubtreeRoots;
mod index;
pub use index::NodeIndex;
@@ -24,14 +20,16 @@ pub use merkle_tree::{path_to_text, tree_to_text, MerkleTree};
mod path;
pub use path::{MerklePath, RootPath, ValuePath};
mod smt;
#[cfg(feature = "internal")]
pub use smt::build_subtree_for_bench;
pub use smt::{
LeafIndex, MutationSet, SimpleSmt, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError,
SubtreeLeaf, SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH,
};
mod mmr;
pub use mmr::{InOrderIndex, Mmr, MmrDelta, MmrError, MmrPeaks, MmrProof, PartialMmr};
mod store;
pub use store::{DefaultMerkleStore, MerkleStore, RecordingMerkleStore, StoreNode};
@@ -59,6 +57,6 @@ const fn int_to_leaf(value: u64) -> Word {
}
#[cfg(test)]
fn digests_to_words(digests: &[RpoDigest]) -> alloc::vec::Vec<Word> {
digests.iter().map(|d| d.into()).collect()
}


@@ -1,4 +1,4 @@
use super::RpoDigest;
/// Representation of a node with two children used for iterating over containers.
#[derive(Debug, Clone, PartialEq, Eq)]


@@ -1,13 +1,18 @@
use alloc::{
collections::{BTreeMap, BTreeSet},
string::String,
vec::Vec,
};
use core::fmt;
use super::{
InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, ValuePath, Word,
EMPTY_WORD,
};
use crate::utils::{
word_to_hex, ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
};
#[cfg(test)]
mod tests;
@@ -109,9 +114,9 @@ impl PartialMerkleTree {
// check if the number of leaves can be accommodated by the tree's depth; we use a min
// depth of 63 because we consider passing in a vector of size 2^64 infeasible.
let max = 2usize.pow(63);
if layers.len() > max {
return Err(MerkleError::TooManyEntries(max));
}
// Get maximum depth
@@ -142,11 +147,12 @@ impl PartialMerkleTree {
let index = NodeIndex::new(depth, index_value)?;
// get hash of the current node
let node =
nodes.get(&index).ok_or(MerkleError::NodeIndexNotFoundInTree(index))?;
// get hash of the sibling node
let sibling = nodes
.get(&index.sibling())
.ok_or(MerkleError::NodeIndexNotFoundInTree(index.sibling()))?;
// get parent hash
let parent = Rpo256::merge(&index.build_node(*node, *sibling));
@@ -179,7 +185,10 @@ impl PartialMerkleTree {
/// # Errors
/// Returns an error if the specified NodeIndex is not contained in the nodes map.
pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
self.nodes
.get(&index)
.ok_or(MerkleError::NodeIndexNotFoundInTree(index))
.copied()
}
/// Returns true if the provided index is contained in the leaves set, false otherwise.
@@ -209,7 +218,7 @@ impl PartialMerkleTree {
/// # Errors
/// Returns an error if:
/// - the specified index has depth set to 0 or the depth is greater than the depth of this
/// Merkle tree.
/// - the specified index is not contained in the nodes map.
pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
if index.is_root() {
@@ -219,7 +228,7 @@ impl PartialMerkleTree {
}
if !self.nodes.contains_key(&index) {
return Err(MerkleError::NodeIndexNotFoundInTree(index));
}
let mut path = Vec::new();
@@ -330,15 +339,16 @@ impl PartialMerkleTree {
if self.root() == EMPTY_DIGEST {
self.nodes.insert(ROOT_INDEX, root);
} else if self.root() != root {
return Err(MerkleError::ConflictingRoots {
expected_root: self.root(),
actual_root: root,
});
}
Ok(())
}
/// Updates value of the leaf at the specified index returning the old leaf value.
///
/// By default the specified index is assumed to belong to the deepest layer. If the considered
/// node does not belong to the tree, the first node on the way to the root will be changed.
@@ -347,6 +357,7 @@ impl PartialMerkleTree {
///
/// # Errors
/// Returns an error if:
/// - No entry exists at the specified index.
/// - The specified index is greater than the maximum number of nodes on the deepest layer.
pub fn update_leaf(&mut self, index: u64, value: Word) -> Result<RpoDigest, MerkleError> {
let mut node_index = NodeIndex::new(self.max_depth(), index)?;
@@ -362,7 +373,7 @@ impl PartialMerkleTree {
let old_value = self
.nodes
.insert(node_index, value.into())
.ok_or(MerkleError::NodeIndexNotFoundInTree(node_index))?;
// if the old value and new value are the same, there is nothing to update
if value == *old_value {


@@ -1,9 +1,11 @@
use alloc::{collections::BTreeMap, vec::Vec};
use super::{
super::{
digests_to_words, int_to_node, DefaultMerkleStore as MerkleStore, MerkleTree, NodeIndex,
PartialMerkleTree,
},
Deserializable, InnerNodeInfo, RpoDigest, Serializable, ValuePath,
};
// TEST DATA
@@ -209,7 +211,7 @@ fn get_paths() {
// Which have leaf nodes 20, 22, 23, 32 and 33. Hence overall we will have 5 paths -- one path
// for each leaf.
let leaves = [NODE20, NODE22, NODE23, NODE32, NODE33];
let expected_paths: Vec<(NodeIndex, ValuePath)> = leaves
.iter()
.map(|&leaf| {
@@ -257,7 +259,7 @@ fn leaves() {
let value32 = mt.get_node(NODE32).unwrap();
let value33 = mt.get_node(NODE33).unwrap();
let leaves = [(NODE11, value11), (NODE20, value20), (NODE32, value32), (NODE33, value33)];
let expected_leaves = leaves.iter().copied();
assert!(expected_leaves.eq(pmt.leaves()));
@@ -293,7 +295,8 @@ fn leaves() {
assert!(expected_leaves.eq(pmt.leaves()));
}
/// Checks that nodes of the PMT returned by `inner_nodes()` function are equal to the expected
/// ones.
#[test]
fn test_inner_node_iterator() {
let mt = MerkleTree::new(digests_to_words(&VALUES8)).unwrap();


@@ -1,6 +1,12 @@
use alloc::vec::Vec;
use core::ops::{Deref, DerefMut};
use super::{InnerNodeInfo, MerkleError, NodeIndex, Rpo256, RpoDigest};
use crate::{
utils::{ByteReader, Deserializable, DeserializationError, Serializable},
Word,
};
// MERKLE PATH
// ================================================================================================
@@ -17,6 +23,7 @@ impl MerklePath {
/// Creates a new Merkle path from a list of nodes.
pub fn new(nodes: Vec<RpoDigest>) -> Self {
assert!(nodes.len() <= u8::MAX.into(), "MerklePath may have at most 256 items");
Self { nodes }
}
@@ -28,6 +35,11 @@ impl MerklePath {
self.nodes.len() as u8
}
/// Returns a reference to the [MerklePath]'s nodes.
pub fn nodes(&self) -> &[RpoDigest] {
&self.nodes
}
/// Computes the merkle root for this opening.
pub fn compute_root(&self, index: u64, node: RpoDigest) -> Result<RpoDigest, MerkleError> {
let mut index = NodeIndex::new(self.depth(), index)?;
@@ -42,12 +54,20 @@ impl MerklePath {
/// Verifies the Merkle opening proof towards the provided root.
///
/// Returns `true` if `node` exists at `index` in a Merkle tree with `root`.
/// # Errors
/// Returns an error if:
/// - provided node index is invalid.
/// - root calculated during the verification differs from the provided one.
pub fn verify(&self, index: u64, node: RpoDigest, root: &RpoDigest) -> Result<(), MerkleError> {
let computed_root = self.compute_root(index, node)?;
if &computed_root != root {
return Err(MerkleError::ConflictingRoots {
expected_root: *root,
actual_root: computed_root,
});
}
Ok(())
}
/// Returns an iterator over every inner node of this [MerklePath].
@@ -69,6 +89,9 @@ impl MerklePath {
}
}
// CONVERSIONS
// ================================================================================================
impl From<MerklePath> for Vec<RpoDigest> {
fn from(path: MerklePath) -> Self {
path.nodes
@@ -81,6 +104,12 @@ impl From<Vec<RpoDigest>> for MerklePath {
}
}
impl From<&[RpoDigest]> for MerklePath {
fn from(path: &[RpoDigest]) -> Self {
Self::new(path.to_vec())
}
}
impl Deref for MerklePath {
// we use `Vec` here instead of slice so we can call vector mutation methods directly from the
// merkle path (example: `Vec::remove`).
@@ -108,7 +137,7 @@ impl FromIterator<RpoDigest> for MerklePath {
impl IntoIterator for MerklePath {
type Item = RpoDigest;
type IntoIter = alloc::vec::IntoIter<RpoDigest>;
fn into_iter(self) -> Self::IntoIter {
self.nodes.into_iter()
@@ -122,7 +151,7 @@ pub struct InnerNodeIterator<'a> {
value: RpoDigest,
}
impl Iterator for InnerNodeIterator<'_> {
type Item = InnerNodeInfo;
fn next(&mut self) -> Option<Self::Item> {
@@ -147,7 +176,7 @@ impl<'a> Iterator for InnerNodeIterator<'a> {
// MERKLE PATH CONTAINERS
// ================================================================================================
/// A container for a [crate::Word] value and its [MerklePath] opening.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ValuePath {
/// The node value opening for `path`.
@@ -158,12 +187,18 @@ pub struct ValuePath {
impl ValuePath {
/// Returns a new [ValuePath] instantiated from the specified value and path.
pub fn new(value: RpoDigest, path: MerklePath) -> Self {
Self { value, path }
}
}
impl From<(MerklePath, Word)> for ValuePath {
fn from((path, value): (MerklePath, Word)) -> Self {
ValuePath::new(value.into(), path)
}
}
/// A container for a [MerklePath] and its [crate::Word] root.
///
/// This structure does not provide any guarantees regarding the correctness of the path to the
/// root. For more information, check [MerklePath::verify].
@@ -175,6 +210,55 @@ pub struct RootPath {
pub path: MerklePath,
}
// SERIALIZATION
// ================================================================================================
impl Serializable for MerklePath {
fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
assert!(self.nodes.len() <= u8::MAX.into(), "Length enforced in the constructor");
target.write_u8(self.nodes.len() as u8);
target.write_many(&self.nodes);
}
}
impl Deserializable for MerklePath {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let count = source.read_u8()?.into();
let nodes = source.read_many::<RpoDigest>(count)?;
Ok(Self { nodes })
}
}
impl Serializable for ValuePath {
fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
self.value.write_into(target);
self.path.write_into(target);
}
}
impl Deserializable for ValuePath {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let value = RpoDigest::read_from(source)?;
let path = MerklePath::read_from(source)?;
Ok(Self { value, path })
}
}
impl Serializable for RootPath {
fn write_into<W: winter_utils::ByteWriter>(&self, target: &mut W) {
self.root.write_into(target);
self.path.write_into(target);
}
}
impl Deserializable for RootPath {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let root = RpoDigest::read_from(source)?;
let path = MerklePath::read_from(source)?;
Ok(Self { root, path })
}
}
// TESTS
// ================================================================================================


@@ -1,302 +0,0 @@
use super::{
BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, MerkleTreeDelta,
NodeIndex, Rpo256, RpoDigest, StoreNode, TryApplyDiff, Vec, Word,
};
#[cfg(test)]
mod tests;
// SPARSE MERKLE TREE
// ================================================================================================
/// A sparse Merkle tree with 64-bit keys and 4-element leaf values, without compaction.
///
/// The root of the tree is recomputed on each new leaf update.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct SimpleSmt {
depth: u8,
root: RpoDigest,
leaves: BTreeMap<u64, Word>,
branches: BTreeMap<NodeIndex, BranchNode>,
empty_hashes: Vec<RpoDigest>,
}
impl SimpleSmt {
// CONSTANTS
// --------------------------------------------------------------------------------------------
/// Minimum supported depth.
pub const MIN_DEPTH: u8 = 1;
/// Maximum supported depth.
pub const MAX_DEPTH: u8 = 64;
/// Value of an empty leaf.
pub const EMPTY_VALUE: Word = super::EMPTY_WORD;
// CONSTRUCTORS
// --------------------------------------------------------------------------------------------
/// Returns a new [SimpleSmt] instantiated with the specified depth.
///
/// All leaves in the returned tree are set to [ZERO; 4].
///
/// # Errors
/// Returns an error if the depth is 0 or is greater than 64.
pub fn new(depth: u8) -> Result<Self, MerkleError> {
// validate the range of the depth.
if depth < Self::MIN_DEPTH {
return Err(MerkleError::DepthTooSmall(depth));
} else if Self::MAX_DEPTH < depth {
return Err(MerkleError::DepthTooBig(depth as u64));
}
let empty_hashes = EmptySubtreeRoots::empty_hashes(depth).to_vec();
let root = empty_hashes[0];
Ok(Self {
root,
depth,
empty_hashes,
leaves: BTreeMap::new(),
branches: BTreeMap::new(),
})
}
/// Returns a new [SimpleSmt] instantiated with the specified depth and with leaves
/// set as specified by the provided entries.
///
/// All leaves omitted from the entries list are set to [ZERO; 4].
///
/// # Errors
/// Returns an error if:
/// - If the depth is 0 or is greater than 64.
/// - The number of entries exceeds the maximum tree capacity, that is 2^{depth}.
/// - The provided entries contain multiple values for the same key.
pub fn with_leaves<R, I>(depth: u8, entries: R) -> Result<Self, MerkleError>
where
R: IntoIterator<IntoIter = I>,
I: Iterator<Item = (u64, Word)> + ExactSizeIterator,
{
// create an empty tree
let mut tree = Self::new(depth)?;
// check if the number of leaves can be accommodated by the tree's depth; we use a min
// depth of 63 because we consider passing in a vector of size 2^64 infeasible.
let entries = entries.into_iter();
let max = 1 << tree.depth.min(63);
if entries.len() > max {
return Err(MerkleError::InvalidNumEntries(max, entries.len()));
}
// append leaves to the tree returning an error if a duplicate entry for the same key
// is found
let mut empty_entries = BTreeSet::new();
for (key, value) in entries {
let old_value = tree.update_leaf(key, value)?;
if old_value != Self::EMPTY_VALUE || empty_entries.contains(&key) {
return Err(MerkleError::DuplicateValuesForIndex(key));
}
// if we've processed an empty entry, add the key to the set of empty entry keys, and
// if this key was already in the set, return an error
if value == Self::EMPTY_VALUE && !empty_entries.insert(key) {
return Err(MerkleError::DuplicateValuesForIndex(key));
}
}
Ok(tree)
}
// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------
/// Returns the root of this Merkle tree.
pub const fn root(&self) -> RpoDigest {
self.root
}
/// Returns the depth of this Merkle tree.
pub const fn depth(&self) -> u8 {
self.depth
}
/// Returns a node at the specified index.
///
/// # Errors
/// Returns an error if the specified index has depth set to 0 or the depth is greater than
/// the depth of this Merkle tree.
pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
if index.is_root() {
Err(MerkleError::DepthTooSmall(index.depth()))
} else if index.depth() > self.depth() {
Err(MerkleError::DepthTooBig(index.depth() as u64))
} else if index.depth() == self.depth() {
// the lookup in empty_hashes could fail only if empty_hashes were not built correctly
// by the constructor as we check the depth of the lookup above.
Ok(RpoDigest::from(
self.get_leaf_node(index.value())
.unwrap_or_else(|| *self.empty_hashes[index.depth() as usize]),
))
} else {
Ok(self.get_branch_node(&index).parent())
}
}
/// Returns a value of the leaf at the specified index.
///
/// # Errors
/// Returns an error if the index is greater than the maximum tree capacity, that is 2^{depth}.
pub fn get_leaf(&self, index: u64) -> Result<Word, MerkleError> {
let index = NodeIndex::new(self.depth, index)?;
Ok(self.get_node(index)?.into())
}
/// Returns a Merkle path from the node at the specified index to the root.
///
/// The node itself is not included in the path.
///
/// # Errors
/// Returns an error if the specified index has depth set to 0 or the depth is greater than
/// the depth of this Merkle tree.
pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
if index.is_root() {
return Err(MerkleError::DepthTooSmall(index.depth()));
} else if index.depth() > self.depth() {
return Err(MerkleError::DepthTooBig(index.depth() as u64));
}
let mut path = Vec::with_capacity(index.depth() as usize);
for _ in 0..index.depth() {
let is_right = index.is_value_odd();
index.move_up();
let BranchNode { left, right } = self.get_branch_node(&index);
let value = if is_right { left } else { right };
path.push(value);
}
Ok(MerklePath::new(path))
}
/// Return a Merkle path from the leaf at the specified index to the root.
///
/// The leaf itself is not included in the path.
///
/// # Errors
/// Returns an error if the index is greater than the maximum tree capacity, that is 2^{depth}.
pub fn get_leaf_path(&self, index: u64) -> Result<MerklePath, MerkleError> {
let index = NodeIndex::new(self.depth(), index)?;
self.get_path(index)
}
// ITERATORS
// --------------------------------------------------------------------------------------------
/// Returns an iterator over the leaves of this [SimpleSmt].
pub fn leaves(&self) -> impl Iterator<Item = (u64, &Word)> {
self.leaves.iter().map(|(i, w)| (*i, w))
}
/// Returns an iterator over the inner nodes of this Merkle tree.
pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
self.branches.values().map(|e| InnerNodeInfo {
value: e.parent(),
left: e.left,
right: e.right,
})
}
// STATE MUTATORS
// --------------------------------------------------------------------------------------------
/// Updates value of the leaf at the specified index returning the old leaf value.
///
/// This also recomputes all hashes between the leaf and the root, updating the root itself.
///
/// # Errors
/// Returns an error if the index is greater than the maximum tree capacity, that is 2^{depth}.
pub fn update_leaf(&mut self, index: u64, value: Word) -> Result<Word, MerkleError> {
let old_value = self.insert_leaf_node(index, value).unwrap_or(Self::EMPTY_VALUE);
// if the old value and new value are the same, there is nothing to update
if value == old_value {
return Ok(value);
}
let mut index = NodeIndex::new(self.depth(), index)?;
let mut value = RpoDigest::from(value);
for _ in 0..index.depth() {
let is_right = index.is_value_odd();
index.move_up();
let BranchNode { left, right } = self.get_branch_node(&index);
let (left, right) = if is_right { (left, value) } else { (value, right) };
self.insert_branch_node(index, left, right);
value = Rpo256::merge(&[left, right]);
}
self.root = value;
Ok(old_value)
}
// HELPER METHODS
// --------------------------------------------------------------------------------------------
fn get_leaf_node(&self, key: u64) -> Option<Word> {
self.leaves.get(&key).copied()
}
fn insert_leaf_node(&mut self, key: u64, node: Word) -> Option<Word> {
self.leaves.insert(key, node)
}
fn get_branch_node(&self, index: &NodeIndex) -> BranchNode {
self.branches.get(index).cloned().unwrap_or_else(|| {
let node = self.empty_hashes[index.depth() as usize + 1];
BranchNode { left: node, right: node }
})
}
fn insert_branch_node(&mut self, index: NodeIndex, left: RpoDigest, right: RpoDigest) {
let branch = BranchNode { left, right };
self.branches.insert(index, branch);
}
}
// BRANCH NODE
// ================================================================================================
#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
struct BranchNode {
left: RpoDigest,
right: RpoDigest,
}
impl BranchNode {
fn parent(&self) -> RpoDigest {
Rpo256::merge(&[self.left, self.right])
}
}
// TRY APPLY DIFF
// ================================================================================================
impl TryApplyDiff<RpoDigest, StoreNode> for SimpleSmt {
type Error = MerkleError;
type DiffType = MerkleTreeDelta;
fn try_apply(&mut self, diff: MerkleTreeDelta) -> Result<(), MerkleError> {
if diff.depth() != self.depth() {
return Err(MerkleError::InvalidDepth {
expected: self.depth(),
provided: diff.depth(),
});
}
for slot in diff.cleared_slots() {
self.update_leaf(*slot, Self::EMPTY_VALUE)?;
}
for (slot, value) in diff.updated_slots() {
self.update_leaf(*slot, *value)?;
}
Ok(())
}
}
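The bottom-up path recomputation in `update_leaf` above can be sketched in self-contained form. In this illustrative sketch (not the crate's API), a toy 64-bit hash from the standard library stands in for `Rpo256::merge`, plain `(u64, u64)` tuples stand in for `BranchNode`, and a single `empty` digest stands in for the per-depth `empty_hashes` table:

```rust
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};

// Toy stand-in for Rpo256::merge: combines two child digests into a parent digest.
fn merge(left: u64, right: u64) -> u64 {
    let mut h = DefaultHasher::new();
    (left, right).hash(&mut h);
    h.finish()
}

// Sets leaf `index` to `value` in a tree of the given depth and recomputes the
// root. `branches` maps (depth, index) -> (left, right) child digests; missing
// branches default to a pair of `empty` subtree digests (the real code looks up
// a depth-specific empty hash instead).
fn update_leaf(
    branches: &mut HashMap<(u8, u64), (u64, u64)>,
    depth: u8,
    mut index: u64,
    mut value: u64,
    empty: u64,
) -> u64 {
    for d in (0..depth).rev() {
        let is_right = index % 2 == 1;
        index /= 2; // move up to the parent index at depth `d`
        let (left, right) = branches.get(&(d, index)).copied().unwrap_or((empty, empty));
        let (left, right) = if is_right { (left, value) } else { (value, right) };
        branches.insert((d, index), (left, right));
        value = merge(left, right);
    }
    value // the new root
}

fn main() {
    let mut branches = HashMap::new();
    let root_a = update_leaf(&mut branches, 3, 6, 7, 0);
    // Writing the same value again must leave the root unchanged.
    let root_b = update_leaf(&mut branches, 3, 6, 7, 0);
    assert_eq!(root_a, root_b);
    // A different leaf value must move the root.
    let root_c = update_leaf(&mut branches, 3, 6, 8, 0);
    assert_ne!(root_a, root_c);
    println!("ok");
}
```

As in the original, only the nodes on the leaf-to-root path are touched, so an update costs O(depth) hashes regardless of how many leaves are populated.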


@@ -1,251 +0,0 @@
use super::{
super::{InnerNodeInfo, MerkleError, MerkleTree, RpoDigest, SimpleSmt, EMPTY_WORD},
NodeIndex, Rpo256, Vec,
};
use crate::{
merkle::{digests_to_words, int_to_leaf, int_to_node},
Word,
};
// TEST DATA
// ================================================================================================
const KEYS4: [u64; 4] = [0, 1, 2, 3];
const KEYS8: [u64; 8] = [0, 1, 2, 3, 4, 5, 6, 7];
const VALUES4: [RpoDigest; 4] = [int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4)];
const VALUES8: [RpoDigest; 8] = [
int_to_node(1),
int_to_node(2),
int_to_node(3),
int_to_node(4),
int_to_node(5),
int_to_node(6),
int_to_node(7),
int_to_node(8),
];
const ZERO_VALUES8: [Word; 8] = [int_to_leaf(0); 8];
// TESTS
// ================================================================================================
#[test]
fn build_empty_tree() {
// tree of depth 3
let smt = SimpleSmt::new(3).unwrap();
let mt = MerkleTree::new(ZERO_VALUES8.to_vec()).unwrap();
assert_eq!(mt.root(), smt.root());
}
#[test]
fn build_sparse_tree() {
let mut smt = SimpleSmt::new(3).unwrap();
let mut values = ZERO_VALUES8.to_vec();
// insert single value
let key = 6;
let new_node = int_to_leaf(7);
values[key as usize] = new_node;
let old_value = smt.update_leaf(key, new_node).expect("Failed to update leaf");
let mt2 = MerkleTree::new(values.clone()).unwrap();
assert_eq!(mt2.root(), smt.root());
assert_eq!(
mt2.get_path(NodeIndex::make(3, 6)).unwrap(),
smt.get_path(NodeIndex::make(3, 6)).unwrap()
);
assert_eq!(old_value, EMPTY_WORD);
// insert second value at distinct leaf branch
let key = 2;
let new_node = int_to_leaf(3);
values[key as usize] = new_node;
let old_value = smt.update_leaf(key, new_node).expect("Failed to update leaf");
let mt3 = MerkleTree::new(values).unwrap();
assert_eq!(mt3.root(), smt.root());
assert_eq!(
mt3.get_path(NodeIndex::make(3, 2)).unwrap(),
smt.get_path(NodeIndex::make(3, 2)).unwrap()
);
assert_eq!(old_value, EMPTY_WORD);
}
#[test]
fn test_depth2_tree() {
let tree =
SimpleSmt::with_leaves(2, KEYS4.into_iter().zip(digests_to_words(&VALUES4).into_iter()))
.unwrap();
// check internal structure
let (root, node2, node3) = compute_internal_nodes();
assert_eq!(root, tree.root());
assert_eq!(node2, tree.get_node(NodeIndex::make(1, 0)).unwrap());
assert_eq!(node3, tree.get_node(NodeIndex::make(1, 1)).unwrap());
// check get_node()
assert_eq!(VALUES4[0], tree.get_node(NodeIndex::make(2, 0)).unwrap());
assert_eq!(VALUES4[1], tree.get_node(NodeIndex::make(2, 1)).unwrap());
assert_eq!(VALUES4[2], tree.get_node(NodeIndex::make(2, 2)).unwrap());
assert_eq!(VALUES4[3], tree.get_node(NodeIndex::make(2, 3)).unwrap());
// check get_path(): depth 2
assert_eq!(vec![VALUES4[1], node3], *tree.get_path(NodeIndex::make(2, 0)).unwrap());
assert_eq!(vec![VALUES4[0], node3], *tree.get_path(NodeIndex::make(2, 1)).unwrap());
assert_eq!(vec![VALUES4[3], node2], *tree.get_path(NodeIndex::make(2, 2)).unwrap());
assert_eq!(vec![VALUES4[2], node2], *tree.get_path(NodeIndex::make(2, 3)).unwrap());
// check get_path(): depth 1
assert_eq!(vec![node3], *tree.get_path(NodeIndex::make(1, 0)).unwrap());
assert_eq!(vec![node2], *tree.get_path(NodeIndex::make(1, 1)).unwrap());
}
#[test]
fn test_inner_node_iterator() -> Result<(), MerkleError> {
let tree =
SimpleSmt::with_leaves(2, KEYS4.into_iter().zip(digests_to_words(&VALUES4).into_iter()))
.unwrap();
// check depth 2
assert_eq!(VALUES4[0], tree.get_node(NodeIndex::make(2, 0)).unwrap());
assert_eq!(VALUES4[1], tree.get_node(NodeIndex::make(2, 1)).unwrap());
assert_eq!(VALUES4[2], tree.get_node(NodeIndex::make(2, 2)).unwrap());
assert_eq!(VALUES4[3], tree.get_node(NodeIndex::make(2, 3)).unwrap());
// get parent nodes
let root = tree.root();
let l1n0 = tree.get_node(NodeIndex::make(1, 0))?;
let l1n1 = tree.get_node(NodeIndex::make(1, 1))?;
let l2n0 = tree.get_node(NodeIndex::make(2, 0))?;
let l2n1 = tree.get_node(NodeIndex::make(2, 1))?;
let l2n2 = tree.get_node(NodeIndex::make(2, 2))?;
let l2n3 = tree.get_node(NodeIndex::make(2, 3))?;
let nodes: Vec<InnerNodeInfo> = tree.inner_nodes().collect();
let expected = vec![
InnerNodeInfo { value: root, left: l1n0, right: l1n1 },
InnerNodeInfo { value: l1n0, left: l2n0, right: l2n1 },
InnerNodeInfo { value: l1n1, left: l2n2, right: l2n3 },
];
assert_eq!(nodes, expected);
Ok(())
}
#[test]
fn update_leaf() {
let mut tree =
SimpleSmt::with_leaves(3, KEYS8.into_iter().zip(digests_to_words(&VALUES8).into_iter()))
.unwrap();
// update one value
let key = 3;
let new_node = int_to_leaf(9);
let mut expected_values = digests_to_words(&VALUES8);
expected_values[key] = new_node;
let expected_tree = MerkleTree::new(expected_values.clone()).unwrap();
let old_leaf = tree.update_leaf(key as u64, new_node).unwrap();
assert_eq!(expected_tree.root(), tree.root);
assert_eq!(old_leaf, *VALUES8[key]);
// update another value
let key = 6;
let new_node = int_to_leaf(10);
expected_values[key] = new_node;
let expected_tree = MerkleTree::new(expected_values.clone()).unwrap();
let old_leaf = tree.update_leaf(key as u64, new_node).unwrap();
assert_eq!(expected_tree.root(), tree.root);
assert_eq!(old_leaf, *VALUES8[key]);
}
#[test]
fn small_tree_opening_is_consistent() {
// ____k____
// / \
// _i_ _j_
// / \ / \
// e f g h
// / \ / \ / \ / \
// a b 0 0 c 0 0 d
let z = EMPTY_WORD;
let a = Word::from(Rpo256::merge(&[z.into(); 2]));
let b = Word::from(Rpo256::merge(&[a.into(); 2]));
let c = Word::from(Rpo256::merge(&[b.into(); 2]));
let d = Word::from(Rpo256::merge(&[c.into(); 2]));
let e = Rpo256::merge(&[a.into(), b.into()]);
let f = Rpo256::merge(&[z.into(), z.into()]);
let g = Rpo256::merge(&[c.into(), z.into()]);
let h = Rpo256::merge(&[z.into(), d.into()]);
let i = Rpo256::merge(&[e, f]);
let j = Rpo256::merge(&[g, h]);
let k = Rpo256::merge(&[i, j]);
let depth = 3;
let entries = vec![(0, a), (1, b), (4, c), (7, d)];
let tree = SimpleSmt::with_leaves(depth, entries).unwrap();
assert_eq!(tree.root(), k);
let cases: Vec<(u8, u64, Vec<RpoDigest>)> = vec![
(3, 0, vec![b.into(), f, j]),
(3, 1, vec![a.into(), f, j]),
(3, 4, vec![z.into(), h, i]),
(3, 7, vec![z.into(), g, i]),
(2, 0, vec![f, j]),
(2, 1, vec![e, j]),
(2, 2, vec![h, i]),
(2, 3, vec![g, i]),
(1, 0, vec![j]),
(1, 1, vec![i]),
];
for (depth, key, path) in cases {
let opening = tree.get_path(NodeIndex::make(depth, key)).unwrap();
assert_eq!(path, *opening);
}
}
#[test]
fn fail_on_duplicates() {
let entries = [(1_u64, int_to_leaf(1)), (5, int_to_leaf(2)), (1_u64, int_to_leaf(3))];
let smt = SimpleSmt::with_leaves(64, entries);
assert!(smt.is_err());
let entries = [(1_u64, int_to_leaf(0)), (5, int_to_leaf(2)), (1_u64, int_to_leaf(0))];
let smt = SimpleSmt::with_leaves(64, entries);
assert!(smt.is_err());
let entries = [(1_u64, int_to_leaf(0)), (5, int_to_leaf(2)), (1_u64, int_to_leaf(1))];
let smt = SimpleSmt::with_leaves(64, entries);
assert!(smt.is_err());
let entries = [(1_u64, int_to_leaf(1)), (5, int_to_leaf(2)), (1_u64, int_to_leaf(0))];
let smt = SimpleSmt::with_leaves(64, entries);
assert!(smt.is_err());
}
#[test]
fn with_no_duplicates_empty_node() {
let entries = [(1_u64, int_to_leaf(0)), (5, int_to_leaf(2))];
let smt = SimpleSmt::with_leaves(64, entries);
assert!(smt.is_ok());
}
// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------
fn compute_internal_nodes() -> (RpoDigest, RpoDigest, RpoDigest) {
let node2 = Rpo256::merge(&[VALUES4[0], VALUES4[1]]);
let node3 = Rpo256::merge(&[VALUES4[2], VALUES4[3]]);
let root = Rpo256::merge(&[node2, node3]);
(root, node2, node3)
}
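The `get_path` assertions in these tests can be mirrored by a small root-from-opening routine: given a leaf digest, its index, and the sibling digests ordered leaf-to-root, fold upward and compare against the root. This is an illustrative sketch with a toy hash standing in for `Rpo256`, not the crate's verifier:

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Toy stand-in for Rpo256::merge.
fn merge(left: u64, right: u64) -> u64 {
    let mut h = DefaultHasher::new();
    (left, right).hash(&mut h);
    h.finish()
}

// Folds an opening (leaf-to-root list of sibling digests) back into a root.
// At each step, the low bit of `index` decides whether the running digest is
// the right or the left child of its parent.
fn root_from_opening(mut index: u64, leaf: u64, siblings: &[u64]) -> u64 {
    let mut acc = leaf;
    for &sib in siblings {
        acc = if index % 2 == 1 { merge(sib, acc) } else { merge(acc, sib) };
        index /= 2;
    }
    acc
}

fn main() {
    // Build a depth-2 tree over leaves [a, b, c, d] by hand.
    let (a, b, c, d) = (1u64, 2, 3, 4);
    let (n0, n1) = (merge(a, b), merge(c, d));
    let root = merge(n0, n1);
    // Opening for leaf index 2 (value c): siblings are [d, n0].
    assert_eq!(root_from_opening(2, c, &[d, n0]), root);
    // Opening for leaf index 1 (value b): siblings are [a, n1].
    assert_eq!(root_from_opening(1, b, &[a, n1]), root);
    println!("ok");
}
```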


@@ -0,0 +1,39 @@
use thiserror::Error;
use crate::{
hash::rpo::RpoDigest,
merkle::{LeafIndex, SMT_DEPTH},
};
// SMT LEAF ERROR
// =================================================================================================
#[derive(Debug, Error)]
pub enum SmtLeafError {
#[error(
"multiple leaf requires all keys to map to the same leaf index but key1 {key_1} and key2 {key_2} map to different indices"
)]
InconsistentMultipleLeafKeys { key_1: RpoDigest, key_2: RpoDigest },
#[error("single leaf key {key} maps to {actual_leaf_index:?} but was expected to map to {expected_leaf_index:?}")]
InconsistentSingleLeafIndices {
key: RpoDigest,
expected_leaf_index: LeafIndex<SMT_DEPTH>,
actual_leaf_index: LeafIndex<SMT_DEPTH>,
},
#[error("supplied leaf index {leaf_index_supplied:?} does not match {leaf_index_from_keys:?} for multiple leaf")]
InconsistentMultipleLeafIndices {
leaf_index_from_keys: LeafIndex<SMT_DEPTH>,
leaf_index_supplied: LeafIndex<SMT_DEPTH>,
},
#[error("multiple leaf requires at least two entries but only {0} were given")]
MultipleLeafRequiresTwoEntries(usize),
}
// SMT PROOF ERROR
// =================================================================================================
#[derive(Debug, Error)]
pub enum SmtProofError {
#[error("merkle path length {0} does not match SMT depth {SMT_DEPTH}")]
InvalidMerklePathLength(usize),
}

src/merkle/smt/full/leaf.rs

@@ -0,0 +1,373 @@
use alloc::{string::ToString, vec::Vec};
use core::cmp::Ordering;
use super::{Felt, LeafIndex, Rpo256, RpoDigest, SmtLeafError, Word, EMPTY_WORD, SMT_DEPTH};
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum SmtLeaf {
Empty(LeafIndex<SMT_DEPTH>),
Single((RpoDigest, Word)),
Multiple(Vec<(RpoDigest, Word)>),
}
impl SmtLeaf {
// CONSTRUCTORS
// ---------------------------------------------------------------------------------------------
/// Returns a new leaf with the specified entries
///
/// # Errors
/// - Returns an error if 2 keys in `entries` map to different leaf indices
/// - Returns an error if 1 or more keys in `entries` map to a leaf index different from
///   `leaf_index`
pub fn new(
entries: Vec<(RpoDigest, Word)>,
leaf_index: LeafIndex<SMT_DEPTH>,
) -> Result<Self, SmtLeafError> {
match entries.len() {
0 => Ok(Self::new_empty(leaf_index)),
1 => {
let (key, value) = entries[0];
let computed_index = LeafIndex::<SMT_DEPTH>::from(key);
if computed_index != leaf_index {
return Err(SmtLeafError::InconsistentSingleLeafIndices {
key,
expected_leaf_index: leaf_index,
actual_leaf_index: computed_index,
});
}
Ok(Self::new_single(key, value))
},
_ => {
let leaf = Self::new_multiple(entries)?;
// `new_multiple()` checked that all keys map to the same leaf index. We still need
// to ensure that that leaf index is `leaf_index`.
if leaf.index() != leaf_index {
Err(SmtLeafError::InconsistentMultipleLeafIndices {
leaf_index_from_keys: leaf.index(),
leaf_index_supplied: leaf_index,
})
} else {
Ok(leaf)
}
},
}
}
/// Returns a new empty leaf with the specified leaf index
pub fn new_empty(leaf_index: LeafIndex<SMT_DEPTH>) -> Self {
Self::Empty(leaf_index)
}
/// Returns a new single leaf with the specified entry. The leaf index is derived from the
/// entry's key.
pub fn new_single(key: RpoDigest, value: Word) -> Self {
Self::Single((key, value))
}
/// Returns a new multiple leaf with the specified entries. The leaf index is derived from the
/// entries' keys.
///
/// # Errors
/// - Returns an error if 2 keys in `entries` map to different leaf indices
pub fn new_multiple(entries: Vec<(RpoDigest, Word)>) -> Result<Self, SmtLeafError> {
if entries.len() < 2 {
return Err(SmtLeafError::MultipleLeafRequiresTwoEntries(entries.len()));
}
// Check that all keys map to the same leaf index
{
let mut keys = entries.iter().map(|(key, _)| key);
let first_key = *keys.next().expect("ensured at least 2 entries");
let first_leaf_index: LeafIndex<SMT_DEPTH> = first_key.into();
for &next_key in keys {
let next_leaf_index: LeafIndex<SMT_DEPTH> = next_key.into();
if next_leaf_index != first_leaf_index {
return Err(SmtLeafError::InconsistentMultipleLeafKeys {
key_1: first_key,
key_2: next_key,
});
}
}
}
Ok(Self::Multiple(entries))
}
// PUBLIC ACCESSORS
// ---------------------------------------------------------------------------------------------
/// Returns true if the leaf is empty
pub fn is_empty(&self) -> bool {
matches!(self, Self::Empty(_))
}
/// Returns the leaf's index in the [`super::Smt`]
pub fn index(&self) -> LeafIndex<SMT_DEPTH> {
match self {
SmtLeaf::Empty(leaf_index) => *leaf_index,
SmtLeaf::Single((key, _)) => key.into(),
SmtLeaf::Multiple(entries) => {
// Note: All keys are guaranteed to have the same leaf index
let (first_key, _) = entries[0];
first_key.into()
},
}
}
/// Returns the number of entries stored in the leaf
pub fn num_entries(&self) -> u64 {
match self {
SmtLeaf::Empty(_) => 0,
SmtLeaf::Single(_) => 1,
SmtLeaf::Multiple(entries) => {
entries.len().try_into().expect("shouldn't have more than 2^64 entries")
},
}
}
/// Computes the hash of the leaf
pub fn hash(&self) -> RpoDigest {
match self {
SmtLeaf::Empty(_) => EMPTY_WORD.into(),
SmtLeaf::Single((key, value)) => Rpo256::merge(&[*key, value.into()]),
SmtLeaf::Multiple(kvs) => {
let elements: Vec<Felt> = kvs.iter().copied().flat_map(kv_to_elements).collect();
Rpo256::hash_elements(&elements)
},
}
}
// ITERATORS
// ---------------------------------------------------------------------------------------------
/// Returns the key-value pairs in the leaf
pub fn entries(&self) -> Vec<&(RpoDigest, Word)> {
match self {
SmtLeaf::Empty(_) => Vec::new(),
SmtLeaf::Single(kv_pair) => vec![kv_pair],
SmtLeaf::Multiple(kv_pairs) => kv_pairs.iter().collect(),
}
}
// CONVERSIONS
// ---------------------------------------------------------------------------------------------
/// Converts a leaf to a list of field elements
pub fn to_elements(&self) -> Vec<Felt> {
self.clone().into_elements()
}
/// Converts a leaf to a list of field elements
pub fn into_elements(self) -> Vec<Felt> {
self.into_entries().into_iter().flat_map(kv_to_elements).collect()
}
/// Converts a leaf into the key-value pairs it contains
pub fn into_entries(self) -> Vec<(RpoDigest, Word)> {
match self {
SmtLeaf::Empty(_) => Vec::new(),
SmtLeaf::Single(kv_pair) => vec![kv_pair],
SmtLeaf::Multiple(kv_pairs) => kv_pairs,
}
}
// HELPERS
// ---------------------------------------------------------------------------------------------
/// Returns the value associated with `key` in the leaf, or `None` if `key` maps to another
/// leaf.
pub(super) fn get_value(&self, key: &RpoDigest) -> Option<Word> {
// Ensure that `key` maps to this leaf
if self.index() != key.into() {
return None;
}
match self {
SmtLeaf::Empty(_) => Some(EMPTY_WORD),
SmtLeaf::Single((key_in_leaf, value_in_leaf)) => {
if key == key_in_leaf {
Some(*value_in_leaf)
} else {
Some(EMPTY_WORD)
}
},
SmtLeaf::Multiple(kv_pairs) => {
for (key_in_leaf, value_in_leaf) in kv_pairs {
if key == key_in_leaf {
return Some(*value_in_leaf);
}
}
Some(EMPTY_WORD)
},
}
}
/// Inserts a key-value pair into the leaf; returns the previous value associated with `key`, if
/// any.
///
/// The caller must ensure that `key` has the same leaf index as all other keys in the leaf.
pub(super) fn insert(&mut self, key: RpoDigest, value: Word) -> Option<Word> {
match self {
SmtLeaf::Empty(_) => {
*self = SmtLeaf::new_single(key, value);
None
},
SmtLeaf::Single(kv_pair) => {
if kv_pair.0 == key {
// the key is already in this leaf. Update the value and return the previous
// value
let old_value = kv_pair.1;
kv_pair.1 = value;
Some(old_value)
} else {
// Another entry is present in this leaf. Transform the entry into a list
// entry, and make sure the key-value pairs are sorted by key
let mut pairs = vec![*kv_pair, (key, value)];
pairs.sort_by(|(key_1, _), (key_2, _)| cmp_keys(*key_1, *key_2));
*self = SmtLeaf::Multiple(pairs);
None
}
},
SmtLeaf::Multiple(kv_pairs) => {
match kv_pairs.binary_search_by(|kv_pair| cmp_keys(kv_pair.0, key)) {
Ok(pos) => {
let old_value = kv_pairs[pos].1;
kv_pairs[pos].1 = value;
Some(old_value)
},
Err(pos) => {
kv_pairs.insert(pos, (key, value));
None
},
}
},
}
}
/// Removes the key-value pair associated with `key` from the leaf; returns the previous value
/// associated with `key`, if any. Also returns an `is_empty` flag indicating whether the leaf
/// became empty and must be removed from the data structure containing it.
pub(super) fn remove(&mut self, key: RpoDigest) -> (Option<Word>, bool) {
match self {
SmtLeaf::Empty(_) => (None, false),
SmtLeaf::Single((key_at_leaf, value_at_leaf)) => {
if *key_at_leaf == key {
// our key was indeed stored in the leaf, so we return the value that was stored
// in it, and indicate that the leaf should be removed
let old_value = *value_at_leaf;
// Note: this is not strictly needed, since the caller is expected to drop this
// `SmtLeaf` object.
*self = SmtLeaf::new_empty(key.into());
(Some(old_value), true)
} else {
// another key is stored at leaf; nothing to update
(None, false)
}
},
SmtLeaf::Multiple(kv_pairs) => {
match kv_pairs.binary_search_by(|kv_pair| cmp_keys(kv_pair.0, key)) {
Ok(pos) => {
let old_value = kv_pairs[pos].1;
kv_pairs.remove(pos);
debug_assert!(!kv_pairs.is_empty());
if kv_pairs.len() == 1 {
// convert the leaf into `Single`
*self = SmtLeaf::Single(kv_pairs[0]);
}
(Some(old_value), false)
},
Err(_) => {
// other keys are stored at leaf; nothing to update
(None, false)
},
}
},
}
}
}
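The Empty → Single → Multiple promotion performed by `insert` above can be exercised with a toy version over sorted `(u64, u64)` pairs. The names here are illustrative; the real leaf stores `(RpoDigest, Word)` pairs ordered by `cmp_keys`:

```rust
// Toy leaf mirroring SmtLeaf's three shapes, keyed by plain u64s.
#[derive(Debug, PartialEq)]
enum Leaf {
    Empty,
    Single((u64, u64)),
    Multiple(Vec<(u64, u64)>),
}

impl Leaf {
    // Inserts a pair, keeping Multiple sorted by key; returns the old value, if any.
    fn insert(&mut self, key: u64, value: u64) -> Option<u64> {
        match self {
            Leaf::Empty => {
                *self = Leaf::Single((key, value));
                None
            },
            Leaf::Single(kv) => {
                if kv.0 == key {
                    // same key: update in place
                    let old = kv.1;
                    kv.1 = value;
                    Some(old)
                } else {
                    // promote to a sorted two-entry list
                    let mut pairs = vec![*kv, (key, value)];
                    pairs.sort_by_key(|&(k, _)| k);
                    *self = Leaf::Multiple(pairs);
                    None
                }
            },
            Leaf::Multiple(pairs) => match pairs.binary_search_by_key(&key, |&(k, _)| k) {
                Ok(pos) => {
                    let old = pairs[pos].1;
                    pairs[pos].1 = value;
                    Some(old)
                },
                Err(pos) => {
                    pairs.insert(pos, (key, value));
                    None
                },
            },
        }
    }
}

fn main() {
    let mut leaf = Leaf::Empty;
    assert_eq!(leaf.insert(5, 50), None); // Empty -> Single
    assert_eq!(leaf.insert(3, 30), None); // Single -> Multiple, sorted by key
    assert_eq!(leaf, Leaf::Multiple(vec![(3, 30), (5, 50)]));
    assert_eq!(leaf.insert(5, 55), Some(50)); // in-place update
    println!("ok");
}
```

Keeping `Multiple` sorted is what makes the `binary_search_by` lookups in `insert` and `remove` correct, and it also gives the leaf a canonical serialization order.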
impl Serializable for SmtLeaf {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
// Write: num entries
self.num_entries().write_into(target);
// Write: leaf index
let leaf_index: u64 = self.index().value();
leaf_index.write_into(target);
// Write: entries
for (key, value) in self.entries() {
key.write_into(target);
value.write_into(target);
}
}
}
impl Deserializable for SmtLeaf {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
// Read: num entries
let num_entries = source.read_u64()?;
// Read: leaf index
let leaf_index: LeafIndex<SMT_DEPTH> = {
let value = source.read_u64()?;
LeafIndex::new_max_depth(value)
};
// Read: entries
let mut entries: Vec<(RpoDigest, Word)> = Vec::new();
for _ in 0..num_entries {
let key: RpoDigest = source.read()?;
let value: Word = source.read()?;
entries.push((key, value));
}
Self::new(entries, leaf_index)
.map_err(|err| DeserializationError::InvalidValue(err.to_string()))
}
}
// HELPER FUNCTIONS
// ================================================================================================
/// Converts a key-value tuple to an iterator of `Felt`s
pub(crate) fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator<Item = Felt> {
let key_elements = key.into_iter();
let value_elements = value.into_iter();
key_elements.chain(value_elements)
}
/// Compares two keys element-by-element using their integer representations, starting with
/// the most significant element.
pub(crate) fn cmp_keys(key_1: RpoDigest, key_2: RpoDigest) -> Ordering {
for (v1, v2) in key_1.iter().zip(key_2.iter()).rev() {
let v1 = v1.as_int();
let v2 = v2.as_int();
if v1 != v2 {
return v1.cmp(&v2);
}
}
Ordering::Equal
}
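The most-significant-element-first ordering used by `cmp_keys` can be sketched over plain `[u64; 4]` arrays standing in for `RpoDigest` (illustrative only):

```rust
use std::cmp::Ordering;

// Compares two 4-element keys element-by-element, starting from the most
// significant (last) element, mirroring the reversed iteration in cmp_keys.
fn cmp_keys(key_1: [u64; 4], key_2: [u64; 4]) -> Ordering {
    for (v1, v2) in key_1.iter().zip(key_2.iter()).rev() {
        if v1 != v2 {
            return v1.cmp(v2);
        }
    }
    Ordering::Equal
}

fn main() {
    // The last element dominates, regardless of earlier elements.
    assert_eq!(cmp_keys([9, 9, 9, 1], [0, 0, 0, 2]), Ordering::Less);
    // Ties on the last element fall through to the next-most-significant one.
    assert_eq!(cmp_keys([0, 0, 5, 7], [0, 0, 3, 7]), Ordering::Greater);
    assert_eq!(cmp_keys([1, 2, 3, 4], [1, 2, 3, 4]), Ordering::Equal);
    println!("ok");
}
```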
