commit 3ea668521986e239dc12034f76424869a12c47d3 Author: Nicholas Ward Date: Fri Mar 3 16:06:04 2023 -0800 first commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ff35523 --- /dev/null +++ b/.gitignore @@ -0,0 +1,11 @@ +# Cargo build +/target +Cargo.lock + +# Profile-guided optimization +/tmp +pgo-data.profdata + +# MacOS nuisances +.DS_Store + diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..0a15665 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "plonky2_ecdsa" +description = "ECDSA gadget for Plonky2" +version = "0.1.0" +license = "MIT OR Apache-2.0" +edition = "2021" + +[features] +parallel = ["plonky2_maybe_rayon/parallel", "plonky2/parallel"] + +[dependencies] +anyhow = { version = "1.0.40", default-features = false } +itertools = { version = "0.10.0", default-features = false } +plonky2_maybe_rayon = { version = "0.1.0", default-features = false } +num = { version = "0.4.0", default-features = false } +plonky2 = { version = "0.1.2", default-features = false } +plonky2_u32 = { version = "0.1.0", default-features = false } +serde = { version = "1.0", default-features = false, features = ["derive"] } + +[dev-dependencies] +rand = { version = "0.8.4", default-features = false, features = ["getrandom"] } diff --git a/LICENSE-APACHE b/LICENSE-APACHE new file mode 100644 index 0000000..1e5006d --- /dev/null +++ b/LICENSE-APACHE @@ -0,0 +1,202 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + diff --git a/LICENSE-MIT b/LICENSE-MIT new file mode 100644 index 0000000..86d690b --- /dev/null +++ b/LICENSE-MIT @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2022 The Plonky2 Authors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..bb4e2d8 --- /dev/null +++ b/README.md @@ -0,0 +1,13 @@ +## License + +Licensed under either of + +* Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) +* MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) + +at your option. + + +### Contribution + +Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any additional terms or conditions. diff --git a/src/curve/curve_adds.rs b/src/curve/curve_adds.rs new file mode 100644 index 0000000..319c561 --- /dev/null +++ b/src/curve/curve_adds.rs @@ -0,0 +1,158 @@ +use core::ops::Add; + +use plonky2::field::ops::Square; +use plonky2::field::types::Field; + +use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; + +impl Add> for ProjectivePoint { + type Output = ProjectivePoint; + + fn add(self, rhs: ProjectivePoint) -> Self::Output { + let ProjectivePoint { + x: x1, + y: y1, + z: z1, + } = self; + let ProjectivePoint { + x: x2, + y: y2, + z: z2, + } = rhs; + + if z1 == C::BaseField::ZERO { + return rhs; + } + if z2 == C::BaseField::ZERO { + return self; + } + + let x1z2 = x1 * z2; + let y1z2 = y1 * z2; + let x2z1 = x2 * z1; + let y2z1 = y2 * z1; + + // Check if we're doubling or adding inverses. + if x1z2 == x2z1 { + if y1z2 == y2z1 { + // TODO: inline to avoid redundant muls. + return self.double(); + } + if y1z2 == -y2z1 { + return ProjectivePoint::ZERO; + } + } + + // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/addition/add-1998-cmo-2 + let z1z2 = z1 * z2; + let u = y2z1 - y1z2; + let uu = u.square(); + let v = x2z1 - x1z2; + let vv = v.square(); + let vvv = v * vv; + let r = vv * x1z2; + let a = uu * z1z2 - vvv - r.double(); + let x3 = v * a; + let y3 = u * (r - a) - vvv * y1z2; + let z3 = vvv * z1z2; + ProjectivePoint::nonzero(x3, y3, z3) + } +} + +impl Add> for ProjectivePoint { + type Output = ProjectivePoint; + + fn add(self, rhs: AffinePoint) -> Self::Output { + let ProjectivePoint { + x: x1, + y: y1, + z: z1, + } = self; + let AffinePoint { + x: x2, + y: y2, + zero: zero2, + } = rhs; + + if z1 == C::BaseField::ZERO { + return rhs.to_projective(); + } + if zero2 { + return self; + } + + let x2z1 = x2 * z1; + let y2z1 = y2 * z1; + + // Check if we're doubling or adding inverses. + if x1 == x2z1 { + if y1 == y2z1 { + // TODO: inline to avoid redundant muls. + return self.double(); + } + if y1 == -y2z1 { + return ProjectivePoint::ZERO; + } + } + + // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/addition/madd-1998-cmo + let u = y2z1 - y1; + let uu = u.square(); + let v = x2z1 - x1; + let vv = v.square(); + let vvv = v * vv; + let r = vv * x1; + let a = uu * z1 - vvv - r.double(); + let x3 = v * a; + let y3 = u * (r - a) - vvv * y1; + let z3 = vvv * z1; + ProjectivePoint::nonzero(x3, y3, z3) + } +} + +impl Add> for AffinePoint { + type Output = ProjectivePoint; + + fn add(self, rhs: AffinePoint) -> Self::Output { + let AffinePoint { + x: x1, + y: y1, + zero: zero1, + } = self; + let AffinePoint { + x: x2, + y: y2, + zero: zero2, + } = rhs; + + if zero1 { + return rhs.to_projective(); + } + if zero2 { + return self.to_projective(); + } + + // Check if we're doubling or adding inverses. + if x1 == x2 { + if y1 == y2 { + return self.to_projective().double(); + } + if y1 == -y2 { + return ProjectivePoint::ZERO; + } + } + + // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/addition/mmadd-1998-cmo + let u = y2 - y1; + let uu = u.square(); + let v = x2 - x1; + let vv = v.square(); + let vvv = v * vv; + let r = vv * x1; + let a = uu - vvv - r.double(); + let x3 = v * a; + let y3 = u * (r - a) - vvv * y1; + let z3 = vvv; + ProjectivePoint::nonzero(x3, y3, z3) + } +} diff --git a/src/curve/curve_msm.rs b/src/curve/curve_msm.rs new file mode 100644 index 0000000..9faa4a7 --- /dev/null +++ b/src/curve/curve_msm.rs @@ -0,0 +1,265 @@ +use alloc::vec::Vec; + +use itertools::Itertools; +use plonky2::field::types::{Field, PrimeField}; +use plonky2_maybe_rayon::*; + +use crate::curve::curve_summation::affine_multisummation_best; +use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; + +/// In Yao's method, we compute an affine summation for each digit. In a parallel setting, it would +/// be easiest to assign individual summations to threads, but this would be sub-optimal because +/// multi-summations can be more efficient than repeating individual summations (see +/// `affine_multisummation_best`). Thus we divide digits into large chunks, and assign chunks of +/// digits to threads. Note that there is a delicate balance here, as large chunks can result in +/// uneven distributions of work among threads. +const DIGITS_PER_CHUNK: usize = 80; + +#[derive(Clone, Debug)] +pub struct MsmPrecomputation { + /// For each generator (in the order they were passed to `msm_precompute`), contains a vector + /// of powers, i.e. [(2^w)^i] for i < DIGITS. + // TODO: Use compressed coordinates here. + powers_per_generator: Vec>>, + + /// The window size. + w: usize, +} + +pub fn msm_precompute( + generators: &[ProjectivePoint], + w: usize, +) -> MsmPrecomputation { + MsmPrecomputation { + powers_per_generator: generators + .into_par_iter() + .map(|&g| precompute_single_generator(g, w)) + .collect(), + w, + } +} + +fn precompute_single_generator(g: ProjectivePoint, w: usize) -> Vec> { + let digits = (C::ScalarField::BITS + w - 1) / w; + let mut powers: Vec> = Vec::with_capacity(digits); + powers.push(g); + for i in 1..digits { + let mut power_i_proj = powers[i - 1]; + for _j in 0..w { + power_i_proj = power_i_proj.double(); + } + powers.push(power_i_proj); + } + ProjectivePoint::batch_to_affine(&powers) +} + +pub fn msm_parallel( + scalars: &[C::ScalarField], + generators: &[ProjectivePoint], + w: usize, +) -> ProjectivePoint { + let precomputation = msm_precompute(generators, w); + msm_execute_parallel(&precomputation, scalars) +} + +pub fn msm_execute( + precomputation: &MsmPrecomputation, + scalars: &[C::ScalarField], +) -> ProjectivePoint { + assert_eq!(precomputation.powers_per_generator.len(), scalars.len()); + let w = precomputation.w; + let digits = (C::ScalarField::BITS + w - 1) / w; + let base = 1 << w; + + // This is a variant of Yao's method, adapted to the multi-scalar setting. Because we use + // extremely large windows, the repeated scans in Yao's method could be more expensive than the + // actual group operations. To avoid this, we store a multimap from each possible digit to the + // positions in which that digit occurs in the scalars. These positions have the form (i, j), + // where i is the index of the generator and j is an index into the digits of the scalar + // associated with that generator. + let mut digit_occurrences: Vec> = Vec::with_capacity(digits); + for _i in 0..base { + digit_occurrences.push(Vec::new()); + } + for (i, scalar) in scalars.iter().enumerate() { + let digits = to_digits::(scalar, w); + for (j, &digit) in digits.iter().enumerate() { + digit_occurrences[digit].push((i, j)); + } + } + + let mut y = ProjectivePoint::ZERO; + let mut u = ProjectivePoint::ZERO; + + for digit in (1..base).rev() { + for &(i, j) in &digit_occurrences[digit] { + u = u + precomputation.powers_per_generator[i][j]; + } + y = y + u; + } + + y +} + +pub fn msm_execute_parallel( + precomputation: &MsmPrecomputation, + scalars: &[C::ScalarField], +) -> ProjectivePoint { + assert_eq!(precomputation.powers_per_generator.len(), scalars.len()); + let w = precomputation.w; + let digits = (C::ScalarField::BITS + w - 1) / w; + let base = 1 << w; + + // This is a variant of Yao's method, adapted to the multi-scalar setting. Because we use + // extremely large windows, the repeated scans in Yao's method could be more expensive than the + // actual group operations. To avoid this, we store a multimap from each possible digit to the + // positions in which that digit occurs in the scalars. These positions have the form (i, j), + // where i is the index of the generator and j is an index into the digits of the scalar + // associated with that generator. + let mut digit_occurrences: Vec> = Vec::with_capacity(digits); + for _i in 0..base { + digit_occurrences.push(Vec::new()); + } + for (i, scalar) in scalars.iter().enumerate() { + let digits = to_digits::(scalar, w); + for (j, &digit) in digits.iter().enumerate() { + digit_occurrences[digit].push((i, j)); + } + } + + // For each digit, we add up the powers associated with all occurrences that digit. + let digits: Vec = (0..base).collect(); + let digit_acc: Vec> = digits + .par_chunks(DIGITS_PER_CHUNK) + .flat_map(|chunk| { + let summations: Vec>> = chunk + .iter() + .map(|&digit| { + digit_occurrences[digit] + .iter() + .map(|&(i, j)| precomputation.powers_per_generator[i][j]) + .collect() + }) + .collect(); + affine_multisummation_best(summations) + }) + .collect(); + // println!("Computing the per-digit summations (in parallel) took {}s", start.elapsed().as_secs_f64()); + + let mut y = ProjectivePoint::ZERO; + let mut u = ProjectivePoint::ZERO; + for digit in (1..base).rev() { + u = u + digit_acc[digit]; + y = y + u; + } + // println!("Final summation (sequential) {}s", start.elapsed().as_secs_f64()); + y +} + +pub(crate) fn to_digits(x: &C::ScalarField, w: usize) -> Vec { + let scalar_bits = C::ScalarField::BITS; + let num_digits = (scalar_bits + w - 1) / w; + + // Convert x to a bool array. + let x_canonical: Vec<_> = x + .to_canonical_biguint() + .to_u64_digits() + .iter() + .cloned() + .pad_using(scalar_bits / 64, |_| 0) + .collect(); + let mut x_bits = Vec::with_capacity(scalar_bits); + for i in 0..scalar_bits { + x_bits.push((x_canonical[i / 64] >> (i as u64 % 64) & 1) != 0); + } + + let mut digits = Vec::with_capacity(num_digits); + for i in 0..num_digits { + let mut digit = 0; + for j in ((i * w)..((i + 1) * w).min(scalar_bits)).rev() { + digit <<= 1; + digit |= x_bits[j] as usize; + } + digits.push(digit); + } + digits +} + +#[cfg(test)] +mod tests { + use alloc::vec; + + use num::BigUint; + use plonky2::field::secp256k1_scalar::Secp256K1Scalar; + + use super::*; + use crate::curve::secp256k1::Secp256K1; + + #[test] + fn test_to_digits() { + let x_canonical = [ + 0b10101010101010101010101010101010, + 0b10101010101010101010101010101010, + 0b11001100110011001100110011001100, + 0b11001100110011001100110011001100, + 0b11110000111100001111000011110000, + 0b11110000111100001111000011110000, + 0b00001111111111111111111111111111, + 0b11111111111111111111111111111111, + ]; + let x = Secp256K1Scalar::from_noncanonical_biguint(BigUint::from_slice(&x_canonical)); + assert_eq!(x.to_canonical_biguint().to_u32_digits(), x_canonical); + assert_eq!( + to_digits::(&x, 17), + vec![ + 0b01010101010101010, + 0b10101010101010101, + 0b01010101010101010, + 0b11001010101010101, + 0b01100110011001100, + 0b00110011001100110, + 0b10011001100110011, + 0b11110000110011001, + 0b01111000011110000, + 0b00111100001111000, + 0b00011110000111100, + 0b11111111111111110, + 0b01111111111111111, + 0b11111111111111000, + 0b11111111111111111, + 0b1, + ] + ); + } + + #[test] + fn test_msm() { + let w = 5; + + let generator_1 = Secp256K1::GENERATOR_PROJECTIVE; + let generator_2 = generator_1 + generator_1; + let generator_3 = generator_1 + generator_2; + + let scalar_1 = Secp256K1Scalar::from_noncanonical_biguint(BigUint::from_slice(&[ + 11111111, 22222222, 33333333, 44444444, + ])); + let scalar_2 = Secp256K1Scalar::from_noncanonical_biguint(BigUint::from_slice(&[ + 22222222, 22222222, 33333333, 44444444, + ])); + let scalar_3 = Secp256K1Scalar::from_noncanonical_biguint(BigUint::from_slice(&[ + 33333333, 22222222, 33333333, 44444444, + ])); + + let generators = vec![generator_1, generator_2, generator_3]; + let scalars = vec![scalar_1, scalar_2, scalar_3]; + + let precomputation = msm_precompute(&generators, w); + let result_msm = msm_execute(&precomputation, &scalars); + + let result_naive = Secp256K1::convert(scalar_1) * generator_1 + + Secp256K1::convert(scalar_2) * generator_2 + + Secp256K1::convert(scalar_3) * generator_3; + + assert_eq!(result_msm, result_naive); + } +} diff --git a/src/curve/curve_multiplication.rs b/src/curve/curve_multiplication.rs new file mode 100644 index 0000000..1f9c653 --- /dev/null +++ b/src/curve/curve_multiplication.rs @@ -0,0 +1,100 @@ +use alloc::vec::Vec; +use core::ops::Mul; + +use plonky2::field::types::{Field, PrimeField}; + +use crate::curve::curve_types::{Curve, CurveScalar, ProjectivePoint}; + +const WINDOW_BITS: usize = 4; +const BASE: usize = 1 << WINDOW_BITS; + +fn digits_per_scalar() -> usize { + (C::ScalarField::BITS + WINDOW_BITS - 1) / WINDOW_BITS +} + +/// Precomputed state used for scalar x ProjectivePoint multiplications, +/// specific to a particular generator. +#[derive(Clone)] +pub struct MultiplicationPrecomputation { + /// [(2^w)^i] g for each i < digits_per_scalar. + powers: Vec>, +} + +impl ProjectivePoint { + pub fn mul_precompute(&self) -> MultiplicationPrecomputation { + let num_digits = digits_per_scalar::(); + let mut powers = Vec::with_capacity(num_digits); + powers.push(*self); + for i in 1..num_digits { + let mut power_i = powers[i - 1]; + for _j in 0..WINDOW_BITS { + power_i = power_i.double(); + } + powers.push(power_i); + } + + MultiplicationPrecomputation { powers } + } + + #[must_use] + pub fn mul_with_precomputation( + &self, + scalar: C::ScalarField, + precomputation: MultiplicationPrecomputation, + ) -> Self { + // Yao's method; see https://koclab.cs.ucsb.edu/teaching/ecc/eccPapers/Doche-ch09.pdf + let precomputed_powers = precomputation.powers; + + let digits = to_digits::(&scalar); + + let mut y = ProjectivePoint::ZERO; + let mut u = ProjectivePoint::ZERO; + let mut all_summands = Vec::new(); + for j in (1..BASE).rev() { + let mut u_summands = Vec::new(); + for (i, &digit) in digits.iter().enumerate() { + if digit == j as u64 { + u_summands.push(precomputed_powers[i]); + } + } + all_summands.push(u_summands); + } + + let all_sums: Vec> = all_summands + .iter() + .cloned() + .map(|vec| vec.iter().fold(ProjectivePoint::ZERO, |a, &b| a + b)) + .collect(); + for i in 0..all_sums.len() { + u = u + all_sums[i]; + y = y + u; + } + y + } +} + +impl Mul> for CurveScalar { + type Output = ProjectivePoint; + + fn mul(self, rhs: ProjectivePoint) -> Self::Output { + let precomputation = rhs.mul_precompute(); + rhs.mul_with_precomputation(self.0, precomputation) + } +} + +#[allow(clippy::assertions_on_constants)] +fn to_digits(x: &C::ScalarField) -> Vec { + debug_assert!( + 64 % WINDOW_BITS == 0, + "For simplicity, only power-of-two window sizes are handled for now" + ); + let digits_per_u64 = 64 / WINDOW_BITS; + let mut digits = Vec::with_capacity(digits_per_scalar::()); + for limb in x.to_canonical_biguint().to_u64_digits() { + for j in 0..digits_per_u64 { + digits.push((limb >> (j * WINDOW_BITS) as u64) % BASE as u64); + } + } + + digits +} diff --git a/src/curve/curve_summation.rs b/src/curve/curve_summation.rs new file mode 100644 index 0000000..7bb633a --- /dev/null +++ b/src/curve/curve_summation.rs @@ -0,0 +1,238 @@ +use alloc::vec; +use alloc::vec::Vec; +use core::iter::Sum; + +use plonky2::field::ops::Square; +use plonky2::field::types::Field; + +use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; + +impl Sum> for ProjectivePoint { + fn sum>>(iter: I) -> ProjectivePoint { + let points: Vec<_> = iter.collect(); + affine_summation_best(points) + } +} + +impl Sum for ProjectivePoint { + fn sum>>(iter: I) -> ProjectivePoint { + iter.fold(ProjectivePoint::ZERO, |acc, x| acc + x) + } +} + +pub fn affine_summation_best(summation: Vec>) -> ProjectivePoint { + let result = affine_multisummation_best(vec![summation]); + debug_assert_eq!(result.len(), 1); + result[0] +} + +pub fn affine_multisummation_best( + summations: Vec>>, +) -> Vec> { + let pairwise_sums: usize = summations.iter().map(|summation| summation.len() / 2).sum(); + + // This threshold is chosen based on data from the summation benchmarks. + if pairwise_sums < 70 { + affine_multisummation_pairwise(summations) + } else { + affine_multisummation_batch_inversion(summations) + } +} + +/// Adds each pair of points using an affine + affine = projective formula, then adds up the +/// intermediate sums using a projective formula. +pub fn affine_multisummation_pairwise( + summations: Vec>>, +) -> Vec> { + summations + .into_iter() + .map(affine_summation_pairwise) + .collect() +} + +/// Adds each pair of points using an affine + affine = projective formula, then adds up the +/// intermediate sums using a projective formula. +pub fn affine_summation_pairwise(points: Vec>) -> ProjectivePoint { + let mut reduced_points: Vec> = Vec::new(); + for chunk in points.chunks(2) { + match chunk.len() { + 1 => reduced_points.push(chunk[0].to_projective()), + 2 => reduced_points.push(chunk[0] + chunk[1]), + _ => panic!(), + } + } + // TODO: Avoid copying (deref) + reduced_points + .iter() + .fold(ProjectivePoint::ZERO, |sum, x| sum + *x) +} + +/// Computes several summations of affine points by applying an affine group law, except that the +/// divisions are batched via Montgomery's trick. +pub fn affine_summation_batch_inversion( + summation: Vec>, +) -> ProjectivePoint { + let result = affine_multisummation_batch_inversion(vec![summation]); + debug_assert_eq!(result.len(), 1); + result[0] +} + +/// Computes several summations of affine points by applying an affine group law, except that the +/// divisions are batched via Montgomery's trick. +pub fn affine_multisummation_batch_inversion( + summations: Vec>>, +) -> Vec> { + let mut elements_to_invert = Vec::new(); + + // For each pair of points, (x1, y1) and (x2, y2), that we're going to add later, we want to + // invert either y (if the points are equal) or x1 - x2 (otherwise). We will use these later. + for summation in &summations { + let n = summation.len(); + // The special case for n=0 is to avoid underflow. + let range_end = if n == 0 { 0 } else { n - 1 }; + + for i in (0..range_end).step_by(2) { + let p1 = summation[i]; + let p2 = summation[i + 1]; + let AffinePoint { + x: x1, + y: y1, + zero: zero1, + } = p1; + let AffinePoint { + x: x2, + y: _y2, + zero: zero2, + } = p2; + + if zero1 || zero2 || p1 == -p2 { + // These are trivial cases where we won't need any inverse. + } else if p1 == p2 { + elements_to_invert.push(y1.double()); + } else { + elements_to_invert.push(x1 - x2); + } + } + } + + let inverses: Vec = + C::BaseField::batch_multiplicative_inverse(&elements_to_invert); + + let mut all_reduced_points = Vec::with_capacity(summations.len()); + let mut inverse_index = 0; + for summation in summations { + let n = summation.len(); + let mut reduced_points = Vec::with_capacity((n + 1) / 2); + + // The special case for n=0 is to avoid underflow. + let range_end = if n == 0 { 0 } else { n - 1 }; + + for i in (0..range_end).step_by(2) { + let p1 = summation[i]; + let p2 = summation[i + 1]; + let AffinePoint { + x: x1, + y: y1, + zero: zero1, + } = p1; + let AffinePoint { + x: x2, + y: y2, + zero: zero2, + } = p2; + + let sum = if zero1 { + p2 + } else if zero2 { + p1 + } else if p1 == -p2 { + AffinePoint::ZERO + } else { + // It's a non-trivial case where we need one of the inverses we computed earlier. + let inverse = inverses[inverse_index]; + inverse_index += 1; + + if p1 == p2 { + // This is the doubling case. + let mut numerator = x1.square().triple(); + if C::A.is_nonzero() { + numerator += C::A; + } + let quotient = numerator * inverse; + let x3 = quotient.square() - x1.double(); + let y3 = quotient * (x1 - x3) - y1; + AffinePoint::nonzero(x3, y3) + } else { + // This is the general case. We use the incomplete addition formulas 4.3 and 4.4. + let quotient = (y1 - y2) * inverse; + let x3 = quotient.square() - x1 - x2; + let y3 = quotient * (x1 - x3) - y1; + AffinePoint::nonzero(x3, y3) + } + }; + reduced_points.push(sum); + } + + // If n is odd, the last point was not part of a pair. + if n % 2 == 1 { + reduced_points.push(summation[n - 1]); + } + + all_reduced_points.push(reduced_points); + } + + // We should have consumed all of the inverses from the batch computation. + debug_assert_eq!(inverse_index, inverses.len()); + + // Recurse with our smaller set of points. + affine_multisummation_best(all_reduced_points) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::curve::secp256k1::Secp256K1; + + #[test] + fn test_pairwise_affine_summation() { + let g_affine = Secp256K1::GENERATOR_AFFINE; + let g2_affine = (g_affine + g_affine).to_affine(); + let g3_affine = (g_affine + g_affine + g_affine).to_affine(); + let g2_proj = g2_affine.to_projective(); + let g3_proj = g3_affine.to_projective(); + assert_eq!( + affine_summation_pairwise::(vec![g_affine, g_affine]), + g2_proj + ); + assert_eq!( + affine_summation_pairwise::(vec![g_affine, g2_affine]), + g3_proj + ); + assert_eq!( + affine_summation_pairwise::(vec![g_affine, g_affine, g_affine]), + g3_proj + ); + assert_eq!( + affine_summation_pairwise::(vec![]), + ProjectivePoint::ZERO + ); + } + + #[test] + fn test_pairwise_affine_summation_batch_inversion() { + let g = Secp256K1::GENERATOR_AFFINE; + let g_proj = g.to_projective(); + assert_eq!( + affine_summation_batch_inversion::(vec![g, g]), + g_proj + g_proj + ); + assert_eq!( + affine_summation_batch_inversion::(vec![g, g, g]), + g_proj + g_proj + g_proj + ); + assert_eq!( + affine_summation_batch_inversion::(vec![]), + ProjectivePoint::ZERO + ); + } +} diff --git a/src/curve/curve_types.rs b/src/curve/curve_types.rs new file mode 100644 index 0000000..9104739 --- /dev/null +++ b/src/curve/curve_types.rs @@ -0,0 +1,286 @@ +use alloc::vec::Vec; +use core::fmt::Debug; +use core::hash::{Hash, Hasher}; +use core::ops::Neg; + +use plonky2::field::ops::Square; +use plonky2::field::types::{Field, PrimeField}; +use serde::{Deserialize, Serialize}; + +// To avoid implementation conflicts from associated types, +// see https://github.com/rust-lang/rust/issues/20400 +pub struct CurveScalar(pub ::ScalarField); + +/// A short Weierstrass curve. +pub trait Curve: 'static + Sync + Sized + Copy + Debug { + type BaseField: PrimeField; + type ScalarField: PrimeField; + + const A: Self::BaseField; + const B: Self::BaseField; + + const GENERATOR_AFFINE: AffinePoint; + + const GENERATOR_PROJECTIVE: ProjectivePoint = ProjectivePoint { + x: Self::GENERATOR_AFFINE.x, + y: Self::GENERATOR_AFFINE.y, + z: Self::BaseField::ONE, + }; + + fn convert(x: Self::ScalarField) -> CurveScalar { + CurveScalar(x) + } + + fn is_safe_curve() -> bool { + // Added additional check to prevent using vulnerabilties in case a discriminant is equal to 0. + (Self::A.cube().double().double() + Self::B.square().triple().triple().triple()) + .is_nonzero() + } +} + +/// A point on a short Weierstrass curve, represented in affine coordinates. +#[derive(Copy, Clone, Debug, Deserialize, Serialize)] +pub struct AffinePoint { + pub x: C::BaseField, + pub y: C::BaseField, + pub zero: bool, +} + +impl AffinePoint { + pub const ZERO: Self = Self { + x: C::BaseField::ZERO, + y: C::BaseField::ZERO, + zero: true, + }; + + pub fn nonzero(x: C::BaseField, y: C::BaseField) -> Self { + let point = Self { x, y, zero: false }; + debug_assert!(point.is_valid()); + point + } + + pub fn is_valid(&self) -> bool { + let Self { x, y, zero } = *self; + zero || y.square() == x.cube() + C::A * x + C::B + } + + pub fn to_projective(&self) -> ProjectivePoint { + let Self { x, y, zero } = *self; + let z = if zero { + C::BaseField::ZERO + } else { + C::BaseField::ONE + }; + + ProjectivePoint { x, y, z } + } + + pub fn batch_to_projective(affine_points: &[Self]) -> Vec> { + affine_points.iter().map(Self::to_projective).collect() + } + + #[must_use] + pub fn double(&self) -> Self { + let AffinePoint { x: x1, y: y1, zero } = *self; + + if zero { + return AffinePoint::ZERO; + } + + let double_y = y1.double(); + let inv_double_y = double_y.inverse(); // (2y)^(-1) + let triple_xx = x1.square().triple(); // 3x^2 + let lambda = (triple_xx + C::A) * inv_double_y; + let x3 = lambda.square() - self.x.double(); + let y3 = lambda * (x1 - x3) - y1; + + Self { + x: x3, + y: y3, + zero: false, + } + } +} + +impl PartialEq for AffinePoint { + fn eq(&self, other: &Self) -> bool { + let AffinePoint { + x: x1, + y: y1, + zero: zero1, + } = *self; + let AffinePoint { + x: x2, + y: y2, + zero: zero2, + } = *other; + if zero1 || zero2 { + return zero1 == zero2; + } + x1 == x2 && y1 == y2 + } +} + +impl Eq for AffinePoint {} + +impl Hash for AffinePoint { + fn hash(&self, state: &mut H) { + if self.zero { + self.zero.hash(state); + } else { + self.x.hash(state); + self.y.hash(state); + } + } +} + +/// A point on a short Weierstrass curve, represented in projective coordinates. +#[derive(Copy, Clone, Debug)] +pub struct ProjectivePoint { + pub x: C::BaseField, + pub y: C::BaseField, + pub z: C::BaseField, +} + +impl ProjectivePoint { + pub const ZERO: Self = Self { + x: C::BaseField::ZERO, + y: C::BaseField::ONE, + z: C::BaseField::ZERO, + }; + + pub fn nonzero(x: C::BaseField, y: C::BaseField, z: C::BaseField) -> Self { + let point = Self { x, y, z }; + debug_assert!(point.is_valid()); + point + } + + pub fn is_valid(&self) -> bool { + let Self { x, y, z } = *self; + z.is_zero() || y.square() * z == x.cube() + C::A * x * z.square() + C::B * z.cube() + } + + pub fn to_affine(&self) -> AffinePoint { + let Self { x, y, z } = *self; + if z == C::BaseField::ZERO { + AffinePoint::ZERO + } else { + let z_inv = z.inverse(); + AffinePoint::nonzero(x * z_inv, y * z_inv) + } + } + + pub fn batch_to_affine(proj_points: &[Self]) -> Vec> { + let n = proj_points.len(); + let zs: Vec = proj_points.iter().map(|pp| pp.z).collect(); + let z_invs = C::BaseField::batch_multiplicative_inverse(&zs); + + let mut result = Vec::with_capacity(n); + for i in 0..n { + let Self { x, y, z } = proj_points[i]; + result.push(if z == C::BaseField::ZERO { + AffinePoint::ZERO + } else { + let z_inv = z_invs[i]; + AffinePoint::nonzero(x * z_inv, y * z_inv) + }); + } + result + } + + // From https://www.hyperelliptic.org/EFD/g1p/data/shortw/projective/doubling/dbl-2007-bl + #[must_use] + pub fn double(&self) -> Self { + let Self { x, y, z } = *self; + if z == C::BaseField::ZERO { + return ProjectivePoint::ZERO; + } + + let xx = x.square(); + let zz = z.square(); + let mut w = xx.triple(); + if C::A.is_nonzero() { + w += C::A * zz; + } + let s = y.double() * z; + let r = y * s; + let rr = r.square(); + let b = (x + r).square() - (xx + rr); + let h = w.square() - b.double(); + let x3 = h * s; + let y3 = w * (b - h) - rr.double(); + let z3 = s.cube(); + Self { + x: x3, + y: y3, + z: z3, + } + } + + pub fn add_slices(a: &[Self], b: &[Self]) -> Vec { + assert_eq!(a.len(), b.len()); + a.iter() + .zip(b.iter()) + .map(|(&a_i, &b_i)| a_i + b_i) + .collect() + } + + #[must_use] + pub fn neg(&self) -> Self { + Self { + x: self.x, + y: -self.y, + z: self.z, + } + } +} + +impl PartialEq for ProjectivePoint { + fn eq(&self, other: &Self) -> bool { + let ProjectivePoint { + x: x1, + y: y1, + z: z1, + } = *self; + let ProjectivePoint { + x: x2, + y: y2, + z: z2, + } = *other; + if z1 == C::BaseField::ZERO || z2 == C::BaseField::ZERO { + return z1 == z2; + } + + // We want to compare (x1/z1, y1/z1) == (x2/z2, y2/z2). + // But to avoid field division, it is better to compare (x1*z2, y1*z2) == (x2*z1, y2*z1). + x1 * z2 == x2 * z1 && y1 * z2 == y2 * z1 + } +} + +impl Eq for ProjectivePoint {} + +impl Neg for AffinePoint { + type Output = AffinePoint; + + fn neg(self) -> Self::Output { + let AffinePoint { x, y, zero } = self; + AffinePoint { x, y: -y, zero } + } +} + +impl Neg for ProjectivePoint { + type Output = ProjectivePoint; + + fn neg(self) -> Self::Output { + let ProjectivePoint { x, y, z } = self; + ProjectivePoint { x, y: -y, z } + } +} + +pub fn base_to_scalar(x: C::BaseField) -> C::ScalarField { + C::ScalarField::from_noncanonical_biguint(x.to_canonical_biguint()) +} + +pub fn scalar_to_base(x: C::ScalarField) -> C::BaseField { + C::BaseField::from_noncanonical_biguint(x.to_canonical_biguint()) +} diff --git a/src/curve/ecdsa.rs b/src/curve/ecdsa.rs new file mode 100644 index 0000000..131d8b4 --- /dev/null +++ b/src/curve/ecdsa.rs @@ -0,0 +1,84 @@ +use plonky2::field::types::{Field, Sample}; +use serde::{Deserialize, Serialize}; + +use crate::curve::curve_msm::msm_parallel; +use crate::curve::curve_types::{base_to_scalar, AffinePoint, Curve, CurveScalar}; + +#[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] +pub struct ECDSASignature { + pub r: C::ScalarField, + pub s: C::ScalarField, +} + +#[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] +pub struct ECDSASecretKey(pub C::ScalarField); + +impl ECDSASecretKey { + pub fn to_public(&self) -> ECDSAPublicKey { + ECDSAPublicKey((CurveScalar(self.0) * C::GENERATOR_PROJECTIVE).to_affine()) + } +} + +#[derive(Copy, Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] +pub struct ECDSAPublicKey(pub AffinePoint); + +pub fn sign_message(msg: C::ScalarField, sk: ECDSASecretKey) -> ECDSASignature { + let (k, rr) = { + let mut k = C::ScalarField::rand(); + let mut rr = (CurveScalar(k) * C::GENERATOR_PROJECTIVE).to_affine(); + while rr.x == C::BaseField::ZERO { + k = C::ScalarField::rand(); + rr = (CurveScalar(k) * C::GENERATOR_PROJECTIVE).to_affine(); + } + (k, rr) + }; + let r = base_to_scalar::(rr.x); + + let s = k.inverse() * (msg + r * sk.0); + + ECDSASignature { r, s } +} + +pub fn verify_message( + msg: C::ScalarField, + sig: ECDSASignature, + pk: ECDSAPublicKey, +) -> bool { + let ECDSASignature { r, s } = sig; + + assert!(pk.0.is_valid()); + + let c = s.inverse(); + let u1 = msg * c; + let u2 = r * c; + + let g = C::GENERATOR_PROJECTIVE; + let w = 5; // Experimentally fastest + let point_proj = msm_parallel(&[u1, u2], &[g, pk.0.to_projective()], w); + let point = point_proj.to_affine(); + + let x = base_to_scalar::(point.x); + r == x +} + +#[cfg(test)] +mod tests { + use plonky2::field::secp256k1_scalar::Secp256K1Scalar; + use plonky2::field::types::Sample; + + use crate::curve::ecdsa::{sign_message, verify_message, ECDSASecretKey}; + use crate::curve::secp256k1::Secp256K1; + + #[test] + fn test_ecdsa_native() { + type C = Secp256K1; + + let msg = Secp256K1Scalar::rand(); + let sk = ECDSASecretKey::(Secp256K1Scalar::rand()); + let pk = sk.to_public(); + + let sig = sign_message(msg, sk); + let result = verify_message(msg, sig, pk); + assert!(result); + } +} diff --git a/src/curve/glv.rs b/src/curve/glv.rs new file mode 100644 index 0000000..7c3e5de --- /dev/null +++ b/src/curve/glv.rs @@ -0,0 +1,140 @@ +use num::rational::Ratio; +use num::BigUint; +use plonky2::field::secp256k1_base::Secp256K1Base; +use plonky2::field::secp256k1_scalar::Secp256K1Scalar; +use plonky2::field::types::{Field, PrimeField}; + +use crate::curve::curve_msm::msm_parallel; +use crate::curve::curve_types::{AffinePoint, ProjectivePoint}; +use crate::curve::secp256k1::Secp256K1; + +pub const GLV_BETA: Secp256K1Base = Secp256K1Base([ + 13923278643952681454, + 11308619431505398165, + 7954561588662645993, + 8856726876819556112, +]); + +pub const GLV_S: Secp256K1Scalar = Secp256K1Scalar([ + 16069571880186789234, + 1310022930574435960, + 11900229862571533402, + 6008836872998760672, +]); + +const A1: Secp256K1Scalar = Secp256K1Scalar([16747920425669159701, 3496713202691238861, 0, 0]); + +const MINUS_B1: Secp256K1Scalar = + Secp256K1Scalar([8022177200260244675, 16448129721693014056, 0, 0]); + +const A2: Secp256K1Scalar = Secp256K1Scalar([6323353552219852760, 1498098850674701302, 1, 0]); + +const B2: Secp256K1Scalar = Secp256K1Scalar([16747920425669159701, 3496713202691238861, 0, 0]); + +/// Algorithm 15.41 in Handbook of Elliptic and Hyperelliptic Curve Cryptography. +/// Decompose a scalar `k` into two small scalars `k1, k2` with `|k1|, |k2| < √p` that satisfy +/// `k1 + s * k2 = k`. +/// Returns `(|k1|, |k2|, k1 < 0, k2 < 0)`. +pub fn decompose_secp256k1_scalar( + k: Secp256K1Scalar, +) -> (Secp256K1Scalar, Secp256K1Scalar, bool, bool) { + let p = Secp256K1Scalar::order(); + let c1_biguint = Ratio::new( + B2.to_canonical_biguint() * k.to_canonical_biguint(), + p.clone(), + ) + .round() + .to_integer(); + let c1 = Secp256K1Scalar::from_noncanonical_biguint(c1_biguint); + let c2_biguint = Ratio::new( + MINUS_B1.to_canonical_biguint() * k.to_canonical_biguint(), + p.clone(), + ) + .round() + .to_integer(); + let c2 = Secp256K1Scalar::from_noncanonical_biguint(c2_biguint); + + let k1_raw = k - c1 * A1 - c2 * A2; + let k2_raw = c1 * MINUS_B1 - c2 * B2; + debug_assert!(k1_raw + GLV_S * k2_raw == k); + + let two = BigUint::from_slice(&[2]); + let k1_neg = k1_raw.to_canonical_biguint() > p.clone() / two.clone(); + let k1 = if k1_neg { + Secp256K1Scalar::from_noncanonical_biguint(p.clone() - k1_raw.to_canonical_biguint()) + } else { + k1_raw + }; + let k2_neg = k2_raw.to_canonical_biguint() > p.clone() / two; + let k2 = if k2_neg { + Secp256K1Scalar::from_noncanonical_biguint(p - k2_raw.to_canonical_biguint()) + } else { + k2_raw + }; + + (k1, k2, k1_neg, k2_neg) +} + +/// See Section 15.2.1 in Handbook of Elliptic and Hyperelliptic Curve Cryptography. +/// GLV scalar multiplication `k * P = k1 * P + k2 * psi(P)`, where `k = k1 + s * k2` is the +/// decomposition computed in `decompose_secp256k1_scalar(k)` and `psi` is the Secp256k1 +/// endomorphism `psi: (x, y) |-> (beta * x, y)` equivalent to scalar multiplication by `s`. +pub fn glv_mul(p: ProjectivePoint, k: Secp256K1Scalar) -> ProjectivePoint { + let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); + + let p_affine = p.to_affine(); + let sp = AffinePoint:: { + x: p_affine.x * GLV_BETA, + y: p_affine.y, + zero: p_affine.zero, + }; + + let first = if k1_neg { p.neg() } else { p }; + let second = if k2_neg { + sp.to_projective().neg() + } else { + sp.to_projective() + }; + + msm_parallel(&[k1, k2], &[first, second], 5) +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2::field::secp256k1_scalar::Secp256K1Scalar; + use plonky2::field::types::{Field, Sample}; + + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::glv::{decompose_secp256k1_scalar, glv_mul, GLV_S}; + use crate::curve::secp256k1::Secp256K1; + + #[test] + fn test_glv_decompose() -> Result<()> { + let k = Secp256K1Scalar::rand(); + let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); + let one = Secp256K1Scalar::ONE; + let m1 = if k1_neg { -one } else { one }; + let m2 = if k2_neg { -one } else { one }; + + assert!(k1 * m1 + GLV_S * k2 * m2 == k); + + Ok(()) + } + + #[test] + fn test_glv_mul() -> Result<()> { + for _ in 0..20 { + let k = Secp256K1Scalar::rand(); + + let p = CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE; + + let kp = CurveScalar(k) * p; + let glv = glv_mul(p, k); + + assert!(kp == glv); + } + + Ok(()) + } +} diff --git a/src/curve/mod.rs b/src/curve/mod.rs new file mode 100644 index 0000000..1984b0c --- /dev/null +++ b/src/curve/mod.rs @@ -0,0 +1,8 @@ +pub mod curve_adds; +pub mod curve_msm; +pub mod curve_multiplication; +pub mod curve_summation; +pub mod curve_types; +pub mod ecdsa; +pub mod glv; +pub mod secp256k1; diff --git a/src/curve/secp256k1.rs b/src/curve/secp256k1.rs new file mode 100644 index 0000000..0b899a7 --- /dev/null +++ b/src/curve/secp256k1.rs @@ -0,0 +1,100 @@ +use plonky2::field::secp256k1_base::Secp256K1Base; +use plonky2::field::secp256k1_scalar::Secp256K1Scalar; +use plonky2::field::types::Field; +use serde::{Deserialize, Serialize}; + +use crate::curve::curve_types::{AffinePoint, Curve}; + +#[derive(Debug, Copy, Clone, Deserialize, Eq, Hash, PartialEq, Serialize)] +pub struct Secp256K1; + +impl Curve for Secp256K1 { + type BaseField = Secp256K1Base; + type ScalarField = Secp256K1Scalar; + + const A: Secp256K1Base = Secp256K1Base::ZERO; + const B: Secp256K1Base = Secp256K1Base([7, 0, 0, 0]); + const GENERATOR_AFFINE: AffinePoint = AffinePoint { + x: SECP256K1_GENERATOR_X, + y: SECP256K1_GENERATOR_Y, + zero: false, + }; +} + +// 55066263022277343669578718895168534326250603453777594175500187360389116729240 +const SECP256K1_GENERATOR_X: Secp256K1Base = Secp256K1Base([ + 0x59F2815B16F81798, + 0x029BFCDB2DCE28D9, + 0x55A06295CE870B07, + 0x79BE667EF9DCBBAC, +]); + +/// 32670510020758816978083085130507043184471273380659243275938904335757337482424 +const SECP256K1_GENERATOR_Y: Secp256K1Base = Secp256K1Base([ + 0x9C47D08FFB10D4B8, + 0xFD17B448A6855419, + 0x5DA4FBFC0E1108A8, + 0x483ADA7726A3C465, +]); + +#[cfg(test)] +mod tests { + use num::BigUint; + use plonky2::field::secp256k1_scalar::Secp256K1Scalar; + use plonky2::field::types::{Field, PrimeField}; + + use crate::curve::curve_types::{AffinePoint, Curve, ProjectivePoint}; + use crate::curve::secp256k1::Secp256K1; + + #[test] + fn test_generator() { + let g = Secp256K1::GENERATOR_AFFINE; + assert!(g.is_valid()); + + let neg_g = AffinePoint:: { + x: g.x, + y: -g.y, + zero: g.zero, + }; + assert!(neg_g.is_valid()); + } + + #[test] + fn test_naive_multiplication() { + let g = Secp256K1::GENERATOR_PROJECTIVE; + let ten = Secp256K1Scalar::from_canonical_u64(10); + let product = mul_naive(ten, g); + let sum = g + g + g + g + g + g + g + g + g + g; + assert_eq!(product, sum); + } + + #[test] + fn test_g1_multiplication() { + let lhs = Secp256K1Scalar::from_noncanonical_biguint(BigUint::from_slice(&[ + 1111, 2222, 3333, 4444, 5555, 6666, 7777, 8888, + ])); + assert_eq!( + Secp256K1::convert(lhs) * Secp256K1::GENERATOR_PROJECTIVE, + mul_naive(lhs, Secp256K1::GENERATOR_PROJECTIVE) + ); + } + + /// A simple, somewhat inefficient implementation of multiplication which is used as a reference + /// for correctness. + fn mul_naive( + lhs: Secp256K1Scalar, + rhs: ProjectivePoint, + ) -> ProjectivePoint { + let mut g = rhs; + let mut sum = ProjectivePoint::ZERO; + for limb in lhs.to_canonical_biguint().to_u64_digits().iter() { + for j in 0..64 { + if (limb >> j & 1u64) != 0u64 { + sum = sum + g; + } + g = g.double(); + } + } + sum + } +} diff --git a/src/gadgets/biguint.rs b/src/gadgets/biguint.rs new file mode 100644 index 0000000..59e48d0 --- /dev/null +++ b/src/gadgets/biguint.rs @@ -0,0 +1,508 @@ +use alloc::vec; +use alloc::vec::Vec; +use core::marker::PhantomData; + +use num::{BigUint, Integer, Zero}; +use plonky2::field::extension::Extendable; +use plonky2::field::types::{PrimeField, PrimeField64}; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::generator::{GeneratedValues, SimpleGenerator}; +use plonky2::iop::target::{BoolTarget, Target}; +use plonky2::iop::witness::{PartitionWitness, Witness}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2_u32::gadgets::arithmetic_u32::{CircuitBuilderU32, U32Target}; +use plonky2_u32::gadgets::multiple_comparison::list_le_u32_circuit; +use plonky2_u32::witness::{GeneratedValuesU32, WitnessU32}; + +#[derive(Clone, Debug)] +pub struct BigUintTarget { + pub limbs: Vec, +} + +impl BigUintTarget { + pub fn num_limbs(&self) -> usize { + self.limbs.len() + } + + pub fn get_limb(&self, i: usize) -> U32Target { + self.limbs[i] + } +} + +pub trait CircuitBuilderBiguint, const D: usize> { + fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget; + + fn zero_biguint(&mut self) -> BigUintTarget; + + fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget); + + fn pad_biguints( + &mut self, + a: &BigUintTarget, + b: &BigUintTarget, + ) -> (BigUintTarget, BigUintTarget); + + fn cmp_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BoolTarget; + + fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget; + + /// Add two `BigUintTarget`s. + fn add_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; + + /// Subtract two `BigUintTarget`s. We assume that the first is larger than the second. + fn sub_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; + + fn mul_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; + + fn mul_biguint_by_bool(&mut self, a: &BigUintTarget, b: BoolTarget) -> BigUintTarget; + + /// Returns x * y + z. This is no more efficient than mul-then-add; it's purely for convenience (only need to call one CircuitBuilder function). + fn mul_add_biguint( + &mut self, + x: &BigUintTarget, + y: &BigUintTarget, + z: &BigUintTarget, + ) -> BigUintTarget; + + fn div_rem_biguint( + &mut self, + a: &BigUintTarget, + b: &BigUintTarget, + ) -> (BigUintTarget, BigUintTarget); + + fn div_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; + + fn rem_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget; +} + +impl, const D: usize> CircuitBuilderBiguint + for CircuitBuilder +{ + fn constant_biguint(&mut self, value: &BigUint) -> BigUintTarget { + let limb_values = value.to_u32_digits(); + let limbs = limb_values.iter().map(|&l| self.constant_u32(l)).collect(); + + BigUintTarget { limbs } + } + + fn zero_biguint(&mut self) -> BigUintTarget { + self.constant_biguint(&BigUint::zero()) + } + + fn connect_biguint(&mut self, lhs: &BigUintTarget, rhs: &BigUintTarget) { + let min_limbs = lhs.num_limbs().min(rhs.num_limbs()); + for i in 0..min_limbs { + self.connect_u32(lhs.get_limb(i), rhs.get_limb(i)); + } + + for i in min_limbs..lhs.num_limbs() { + self.assert_zero_u32(lhs.get_limb(i)); + } + for i in min_limbs..rhs.num_limbs() { + self.assert_zero_u32(rhs.get_limb(i)); + } + } + + fn pad_biguints( + &mut self, + a: &BigUintTarget, + b: &BigUintTarget, + ) -> (BigUintTarget, BigUintTarget) { + if a.num_limbs() > b.num_limbs() { + let mut padded_b = b.clone(); + for _ in b.num_limbs()..a.num_limbs() { + padded_b.limbs.push(self.zero_u32()); + } + + (a.clone(), padded_b) + } else { + let mut padded_a = a.clone(); + for _ in a.num_limbs()..b.num_limbs() { + padded_a.limbs.push(self.zero_u32()); + } + + (padded_a, b.clone()) + } + } + + fn cmp_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BoolTarget { + let (a, b) = self.pad_biguints(a, b); + + list_le_u32_circuit(self, a.limbs, b.limbs) + } + + fn add_virtual_biguint_target(&mut self, num_limbs: usize) -> BigUintTarget { + let limbs = self.add_virtual_u32_targets(num_limbs); + + BigUintTarget { limbs } + } + + fn add_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { + let num_limbs = a.num_limbs().max(b.num_limbs()); + + let mut combined_limbs = vec![]; + let mut carry = self.zero_u32(); + for i in 0..num_limbs { + let a_limb = (i < a.num_limbs()) + .then(|| a.limbs[i]) + .unwrap_or_else(|| self.zero_u32()); + let b_limb = (i < b.num_limbs()) + .then(|| b.limbs[i]) + .unwrap_or_else(|| self.zero_u32()); + + let (new_limb, new_carry) = self.add_many_u32(&[carry, a_limb, b_limb]); + carry = new_carry; + combined_limbs.push(new_limb); + } + combined_limbs.push(carry); + + BigUintTarget { + limbs: combined_limbs, + } + } + + fn sub_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { + let (a, b) = self.pad_biguints(a, b); + let num_limbs = a.limbs.len(); + + let mut result_limbs = vec![]; + + let mut borrow = self.zero_u32(); + for i in 0..num_limbs { + let (result, new_borrow) = self.sub_u32(a.limbs[i], b.limbs[i], borrow); + result_limbs.push(result); + borrow = new_borrow; + } + // Borrow should be zero here. + + BigUintTarget { + limbs: result_limbs, + } + } + + fn mul_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { + let total_limbs = a.limbs.len() + b.limbs.len(); + + let mut to_add = vec![vec![]; total_limbs]; + for i in 0..a.limbs.len() { + for j in 0..b.limbs.len() { + let (product, carry) = self.mul_u32(a.limbs[i], b.limbs[j]); + to_add[i + j].push(product); + to_add[i + j + 1].push(carry); + } + } + + let mut combined_limbs = vec![]; + let mut carry = self.zero_u32(); + for summands in &mut to_add { + let (new_result, new_carry) = self.add_u32s_with_carry(summands, carry); + combined_limbs.push(new_result); + carry = new_carry; + } + combined_limbs.push(carry); + + BigUintTarget { + limbs: combined_limbs, + } + } + + fn mul_biguint_by_bool(&mut self, a: &BigUintTarget, b: BoolTarget) -> BigUintTarget { + let t = b.target; + + BigUintTarget { + limbs: a + .limbs + .iter() + .map(|&l| U32Target(self.mul(l.0, t))) + .collect(), + } + } + + fn mul_add_biguint( + &mut self, + x: &BigUintTarget, + y: &BigUintTarget, + z: &BigUintTarget, + ) -> BigUintTarget { + let prod = self.mul_biguint(x, y); + self.add_biguint(&prod, z) + } + + fn div_rem_biguint( + &mut self, + a: &BigUintTarget, + b: &BigUintTarget, + ) -> (BigUintTarget, BigUintTarget) { + let a_len = a.limbs.len(); + let b_len = b.limbs.len(); + let div_num_limbs = if b_len > a_len + 1 { + 0 + } else { + a_len - b_len + 1 + }; + let div = self.add_virtual_biguint_target(div_num_limbs); + let rem = self.add_virtual_biguint_target(b_len); + + self.add_simple_generator(BigUintDivRemGenerator:: { + a: a.clone(), + b: b.clone(), + div: div.clone(), + rem: rem.clone(), + _phantom: PhantomData, + }); + + let div_b = self.mul_biguint(&div, b); + let div_b_plus_rem = self.add_biguint(&div_b, &rem); + self.connect_biguint(a, &div_b_plus_rem); + + let cmp_rem_b = self.cmp_biguint(&rem, b); + self.assert_one(cmp_rem_b.target); + + (div, rem) + } + + fn div_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { + let (div, _rem) = self.div_rem_biguint(a, b); + div + } + + fn rem_biguint(&mut self, a: &BigUintTarget, b: &BigUintTarget) -> BigUintTarget { + let (_div, rem) = self.div_rem_biguint(a, b); + rem + } +} + +pub trait WitnessBigUint: Witness { + fn get_biguint_target(&self, target: BigUintTarget) -> BigUint; + fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint); +} + +impl, F: PrimeField64> WitnessBigUint for T { + fn get_biguint_target(&self, target: BigUintTarget) -> BigUint { + target + .limbs + .into_iter() + .rev() + .fold(BigUint::zero(), |acc, limb| { + (acc << 32) + self.get_target(limb.0).to_canonical_biguint() + }) + } + + fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint) { + let mut limbs = value.to_u32_digits(); + assert!(target.num_limbs() >= limbs.len()); + limbs.resize(target.num_limbs(), 0); + for i in 0..target.num_limbs() { + self.set_u32_target(target.limbs[i], limbs[i]); + } + } +} + +pub trait GeneratedValuesBigUint { + fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint); +} + +impl GeneratedValuesBigUint for GeneratedValues { + fn set_biguint_target(&mut self, target: &BigUintTarget, value: &BigUint) { + let mut limbs = value.to_u32_digits(); + assert!(target.num_limbs() >= limbs.len()); + limbs.resize(target.num_limbs(), 0); + for i in 0..target.num_limbs() { + self.set_u32_target(target.get_limb(i), limbs[i]); + } + } +} + +#[derive(Debug)] +struct BigUintDivRemGenerator, const D: usize> { + a: BigUintTarget, + b: BigUintTarget, + div: BigUintTarget, + rem: BigUintTarget, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for BigUintDivRemGenerator +{ + fn dependencies(&self) -> Vec { + self.a + .limbs + .iter() + .chain(&self.b.limbs) + .map(|&l| l.0) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let a = witness.get_biguint_target(self.a.clone()); + let b = witness.get_biguint_target(self.b.clone()); + let (div, rem) = a.div_rem(&b); + + out_buffer.set_biguint_target(&self.div, &div); + out_buffer.set_biguint_target(&self.rem, &rem); + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use num::{BigUint, FromPrimitive, Integer}; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use rand::rngs::OsRng; + use rand::Rng; + + use crate::gadgets::biguint::{CircuitBuilderBiguint, WitnessBigUint}; + + #[test] + fn test_biguint_add() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + let mut rng = OsRng; + + let x_value = BigUint::from_u128(rng.gen()).unwrap(); + let y_value = BigUint::from_u128(rng.gen()).unwrap(); + let expected_z_value = &x_value + &y_value; + + let config = CircuitConfig::standard_recursion_config(); + let mut pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.add_virtual_biguint_target(x_value.to_u32_digits().len()); + let y = builder.add_virtual_biguint_target(y_value.to_u32_digits().len()); + let z = builder.add_biguint(&x, &y); + let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len()); + builder.connect_biguint(&z, &expected_z); + + pw.set_biguint_target(&x, &x_value); + pw.set_biguint_target(&y, &y_value); + pw.set_biguint_target(&expected_z, &expected_z_value); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } + + #[test] + fn test_biguint_sub() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + let mut rng = OsRng; + + let mut x_value = BigUint::from_u128(rng.gen()).unwrap(); + let mut y_value = BigUint::from_u128(rng.gen()).unwrap(); + if y_value > x_value { + (x_value, y_value) = (y_value, x_value); + } + let expected_z_value = &x_value - &y_value; + + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_biguint(&x_value); + let y = builder.constant_biguint(&y_value); + let z = builder.sub_biguint(&x, &y); + let expected_z = builder.constant_biguint(&expected_z_value); + + builder.connect_biguint(&z, &expected_z); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } + + #[test] + fn test_biguint_mul() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + let mut rng = OsRng; + + let x_value = BigUint::from_u128(rng.gen()).unwrap(); + let y_value = BigUint::from_u128(rng.gen()).unwrap(); + let expected_z_value = &x_value * &y_value; + + let config = CircuitConfig::standard_recursion_config(); + let mut pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.add_virtual_biguint_target(x_value.to_u32_digits().len()); + let y = builder.add_virtual_biguint_target(y_value.to_u32_digits().len()); + let z = builder.mul_biguint(&x, &y); + let expected_z = builder.add_virtual_biguint_target(expected_z_value.to_u32_digits().len()); + builder.connect_biguint(&z, &expected_z); + + pw.set_biguint_target(&x, &x_value); + pw.set_biguint_target(&y, &y_value); + pw.set_biguint_target(&expected_z, &expected_z_value); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } + + #[test] + fn test_biguint_cmp() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + let mut rng = OsRng; + + let x_value = BigUint::from_u128(rng.gen()).unwrap(); + let y_value = BigUint::from_u128(rng.gen()).unwrap(); + + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_biguint(&x_value); + let y = builder.constant_biguint(&y_value); + let cmp = builder.cmp_biguint(&x, &y); + let expected_cmp = builder.constant_bool(x_value <= y_value); + + builder.connect(cmp.target, expected_cmp.target); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } + + #[test] + fn test_biguint_div_rem() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + let mut rng = OsRng; + + let mut x_value = BigUint::from_u128(rng.gen()).unwrap(); + let mut y_value = BigUint::from_u128(rng.gen()).unwrap(); + if y_value > x_value { + (x_value, y_value) = (y_value, x_value); + } + let (expected_div_value, expected_rem_value) = x_value.div_rem(&y_value); + + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_biguint(&x_value); + let y = builder.constant_biguint(&y_value); + let (div, rem) = builder.div_rem_biguint(&x, &y); + + let expected_div = builder.constant_biguint(&expected_div_value); + let expected_rem = builder.constant_biguint(&expected_rem_value); + + builder.connect_biguint(&div, &expected_div); + builder.connect_biguint(&rem, &expected_rem); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } +} diff --git a/src/gadgets/curve.rs b/src/gadgets/curve.rs new file mode 100644 index 0000000..1107532 --- /dev/null +++ b/src/gadgets/curve.rs @@ -0,0 +1,486 @@ +use alloc::vec; +use alloc::vec::Vec; + +use plonky2::field::extension::Extendable; +use plonky2::field::types::Sample; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::target::BoolTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; + +use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; +use crate::gadgets::nonnative::{CircuitBuilderNonNative, NonNativeTarget}; + +/// A Target representing an affine point on the curve `C`. We use incomplete arithmetic for efficiency, +/// so we assume these points are not zero. +#[derive(Clone, Debug)] +pub struct AffinePointTarget { + pub x: NonNativeTarget, + pub y: NonNativeTarget, +} + +impl AffinePointTarget { + pub fn to_vec(&self) -> Vec> { + vec![self.x.clone(), self.y.clone()] + } +} + +pub trait CircuitBuilderCurve, const D: usize> { + fn constant_affine_point(&mut self, point: AffinePoint) -> AffinePointTarget; + + fn connect_affine_point( + &mut self, + lhs: &AffinePointTarget, + rhs: &AffinePointTarget, + ); + + fn add_virtual_affine_point_target(&mut self) -> AffinePointTarget; + + fn curve_assert_valid(&mut self, p: &AffinePointTarget); + + fn curve_neg(&mut self, p: &AffinePointTarget) -> AffinePointTarget; + + fn curve_conditional_neg( + &mut self, + p: &AffinePointTarget, + b: BoolTarget, + ) -> AffinePointTarget; + + fn curve_double(&mut self, p: &AffinePointTarget) -> AffinePointTarget; + + fn curve_repeated_double( + &mut self, + p: &AffinePointTarget, + n: usize, + ) -> AffinePointTarget; + + /// Add two points, which are assumed to be non-equal. + fn curve_add( + &mut self, + p1: &AffinePointTarget, + p2: &AffinePointTarget, + ) -> AffinePointTarget; + + fn curve_conditional_add( + &mut self, + p1: &AffinePointTarget, + p2: &AffinePointTarget, + b: BoolTarget, + ) -> AffinePointTarget; + + fn curve_scalar_mul( + &mut self, + p: &AffinePointTarget, + n: &NonNativeTarget, + ) -> AffinePointTarget; +} + +impl, const D: usize> CircuitBuilderCurve + for CircuitBuilder +{ + fn constant_affine_point(&mut self, point: AffinePoint) -> AffinePointTarget { + debug_assert!(!point.zero); + AffinePointTarget { + x: self.constant_nonnative(point.x), + y: self.constant_nonnative(point.y), + } + } + + fn connect_affine_point( + &mut self, + lhs: &AffinePointTarget, + rhs: &AffinePointTarget, + ) { + self.connect_nonnative(&lhs.x, &rhs.x); + self.connect_nonnative(&lhs.y, &rhs.y); + } + + fn add_virtual_affine_point_target(&mut self) -> AffinePointTarget { + let x = self.add_virtual_nonnative_target(); + let y = self.add_virtual_nonnative_target(); + + AffinePointTarget { x, y } + } + + fn curve_assert_valid(&mut self, p: &AffinePointTarget) { + let a = self.constant_nonnative(C::A); + let b = self.constant_nonnative(C::B); + + let y_squared = self.mul_nonnative(&p.y, &p.y); + let x_squared = self.mul_nonnative(&p.x, &p.x); + let x_cubed = self.mul_nonnative(&x_squared, &p.x); + let a_x = self.mul_nonnative(&a, &p.x); + let a_x_plus_b = self.add_nonnative(&a_x, &b); + let rhs = self.add_nonnative(&x_cubed, &a_x_plus_b); + + self.connect_nonnative(&y_squared, &rhs); + } + + fn curve_neg(&mut self, p: &AffinePointTarget) -> AffinePointTarget { + let neg_y = self.neg_nonnative(&p.y); + AffinePointTarget { + x: p.x.clone(), + y: neg_y, + } + } + + fn curve_conditional_neg( + &mut self, + p: &AffinePointTarget, + b: BoolTarget, + ) -> AffinePointTarget { + AffinePointTarget { + x: p.x.clone(), + y: self.nonnative_conditional_neg(&p.y, b), + } + } + + fn curve_double(&mut self, p: &AffinePointTarget) -> AffinePointTarget { + let AffinePointTarget { x, y } = p; + let double_y = self.add_nonnative(y, y); + let inv_double_y = self.inv_nonnative(&double_y); + let x_squared = self.mul_nonnative(x, x); + let double_x_squared = self.add_nonnative(&x_squared, &x_squared); + let triple_x_squared = self.add_nonnative(&double_x_squared, &x_squared); + + let a = self.constant_nonnative(C::A); + let triple_xx_a = self.add_nonnative(&triple_x_squared, &a); + let lambda = self.mul_nonnative(&triple_xx_a, &inv_double_y); + let lambda_squared = self.mul_nonnative(&lambda, &lambda); + let x_double = self.add_nonnative(x, x); + + let x3 = self.sub_nonnative(&lambda_squared, &x_double); + + let x_diff = self.sub_nonnative(x, &x3); + let lambda_x_diff = self.mul_nonnative(&lambda, &x_diff); + + let y3 = self.sub_nonnative(&lambda_x_diff, y); + + AffinePointTarget { x: x3, y: y3 } + } + + fn curve_repeated_double( + &mut self, + p: &AffinePointTarget, + n: usize, + ) -> AffinePointTarget { + let mut result = p.clone(); + + for _ in 0..n { + result = self.curve_double(&result); + } + + result + } + + fn curve_add( + &mut self, + p1: &AffinePointTarget, + p2: &AffinePointTarget, + ) -> AffinePointTarget { + let AffinePointTarget { x: x1, y: y1 } = p1; + let AffinePointTarget { x: x2, y: y2 } = p2; + + let u = self.sub_nonnative(y2, y1); + let v = self.sub_nonnative(x2, x1); + let v_inv = self.inv_nonnative(&v); + let s = self.mul_nonnative(&u, &v_inv); + let s_squared = self.mul_nonnative(&s, &s); + let x_sum = self.add_nonnative(x2, x1); + let x3 = self.sub_nonnative(&s_squared, &x_sum); + let x_diff = self.sub_nonnative(x1, &x3); + let prod = self.mul_nonnative(&s, &x_diff); + let y3 = self.sub_nonnative(&prod, y1); + + AffinePointTarget { x: x3, y: y3 } + } + + fn curve_conditional_add( + &mut self, + p1: &AffinePointTarget, + p2: &AffinePointTarget, + b: BoolTarget, + ) -> AffinePointTarget { + let not_b = self.not(b); + let sum = self.curve_add(p1, p2); + let x_if_true = self.mul_nonnative_by_bool(&sum.x, b); + let y_if_true = self.mul_nonnative_by_bool(&sum.y, b); + let x_if_false = self.mul_nonnative_by_bool(&p1.x, not_b); + let y_if_false = self.mul_nonnative_by_bool(&p1.y, not_b); + + let x = self.add_nonnative(&x_if_true, &x_if_false); + let y = self.add_nonnative(&y_if_true, &y_if_false); + + AffinePointTarget { x, y } + } + + fn curve_scalar_mul( + &mut self, + p: &AffinePointTarget, + n: &NonNativeTarget, + ) -> AffinePointTarget { + let bits = self.split_nonnative_to_bits(n); + + let rando = (CurveScalar(C::ScalarField::rand()) * C::GENERATOR_PROJECTIVE).to_affine(); + let randot = self.constant_affine_point(rando); + // Result starts at `rando`, which is later subtracted, because we don't support arithmetic with the zero point. + let mut result = self.add_virtual_affine_point_target(); + self.connect_affine_point(&randot, &result); + + let mut two_i_times_p = self.add_virtual_affine_point_target(); + self.connect_affine_point(p, &two_i_times_p); + + for &bit in bits.iter() { + let not_bit = self.not(bit); + + let result_plus_2_i_p = self.curve_add(&result, &two_i_times_p); + + let new_x_if_bit = self.mul_nonnative_by_bool(&result_plus_2_i_p.x, bit); + let new_x_if_not_bit = self.mul_nonnative_by_bool(&result.x, not_bit); + let new_y_if_bit = self.mul_nonnative_by_bool(&result_plus_2_i_p.y, bit); + let new_y_if_not_bit = self.mul_nonnative_by_bool(&result.y, not_bit); + + let new_x = self.add_nonnative(&new_x_if_bit, &new_x_if_not_bit); + let new_y = self.add_nonnative(&new_y_if_bit, &new_y_if_not_bit); + + result = AffinePointTarget { x: new_x, y: new_y }; + + two_i_times_p = self.curve_double(&two_i_times_p); + } + + // Subtract off result's intial value of `rando`. + let neg_r = self.curve_neg(&randot); + result = self.curve_add(&result, &neg_r); + + result + } +} + +#[cfg(test)] +mod tests { + use core::ops::Neg; + + use anyhow::Result; + use plonky2::field::secp256k1_base::Secp256K1Base; + use plonky2::field::secp256k1_scalar::Secp256K1Scalar; + use plonky2::field::types::{Field, Sample}; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; + use crate::curve::secp256k1::Secp256K1; + use crate::gadgets::curve::CircuitBuilderCurve; + use crate::gadgets::nonnative::CircuitBuilderNonNative; + + #[test] + fn test_curve_point_is_valid() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let g_target = builder.constant_affine_point(g); + let neg_g_target = builder.curve_neg(&g_target); + + builder.curve_assert_valid(&g_target); + builder.curve_assert_valid(&neg_g_target); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) + } + + #[test] + #[should_panic] + fn test_curve_point_is_not_valid() { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let not_g = AffinePoint:: { + x: g.x, + y: g.y + Secp256K1Base::ONE, + zero: g.zero, + }; + let not_g_target = builder.constant_affine_point(not_g); + + builder.curve_assert_valid(¬_g_target); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof).unwrap() + } + + #[test] + fn test_curve_double() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let g_target = builder.constant_affine_point(g); + let neg_g_target = builder.curve_neg(&g_target); + + let double_g = g.double(); + let double_g_expected = builder.constant_affine_point(double_g); + builder.curve_assert_valid(&double_g_expected); + + let double_neg_g = (-g).double(); + let double_neg_g_expected = builder.constant_affine_point(double_neg_g); + builder.curve_assert_valid(&double_neg_g_expected); + + let double_g_actual = builder.curve_double(&g_target); + let double_neg_g_actual = builder.curve_double(&neg_g_target); + builder.curve_assert_valid(&double_g_actual); + builder.curve_assert_valid(&double_neg_g_actual); + + builder.connect_affine_point(&double_g_expected, &double_g_actual); + builder.connect_affine_point(&double_neg_g_expected, &double_neg_g_actual); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) + } + + #[test] + fn test_curve_add() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let double_g = g.double(); + let g_plus_2g = (g + double_g).to_affine(); + let g_plus_2g_expected = builder.constant_affine_point(g_plus_2g); + builder.curve_assert_valid(&g_plus_2g_expected); + + let g_target = builder.constant_affine_point(g); + let double_g_target = builder.curve_double(&g_target); + let g_plus_2g_actual = builder.curve_add(&g_target, &double_g_target); + builder.curve_assert_valid(&g_plus_2g_actual); + + builder.connect_affine_point(&g_plus_2g_expected, &g_plus_2g_actual); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) + } + + #[test] + fn test_curve_conditional_add() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let double_g = g.double(); + let g_plus_2g = (g + double_g).to_affine(); + let g_plus_2g_expected = builder.constant_affine_point(g_plus_2g); + + let g_expected = builder.constant_affine_point(g); + let double_g_target = builder.curve_double(&g_expected); + let t = builder._true(); + let f = builder._false(); + let g_plus_2g_actual = builder.curve_conditional_add(&g_expected, &double_g_target, t); + let g_actual = builder.curve_conditional_add(&g_expected, &double_g_target, f); + + builder.connect_affine_point(&g_plus_2g_expected, &g_plus_2g_actual); + builder.connect_affine_point(&g_expected, &g_actual); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) + } + + #[test] + #[ignore] + fn test_curve_mul() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_PROJECTIVE.to_affine(); + let five = Secp256K1Scalar::from_canonical_usize(5); + let neg_five = five.neg(); + let neg_five_scalar = CurveScalar::(neg_five); + let neg_five_g = (neg_five_scalar * g.to_projective()).to_affine(); + let neg_five_g_expected = builder.constant_affine_point(neg_five_g); + builder.curve_assert_valid(&neg_five_g_expected); + + let g_target = builder.constant_affine_point(g); + let neg_five_target = builder.constant_nonnative(neg_five); + let neg_five_g_actual = builder.curve_scalar_mul(&g_target, &neg_five_target); + builder.curve_assert_valid(&neg_five_g_actual); + + builder.connect_affine_point(&neg_five_g_expected, &neg_five_g_actual); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) + } + + #[test] + #[ignore] + fn test_curve_random() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let rando = + (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let randot = builder.constant_affine_point(rando); + + let two_target = builder.constant_nonnative(Secp256K1Scalar::TWO); + let randot_doubled = builder.curve_double(&randot); + let randot_times_two = builder.curve_scalar_mul(&randot, &two_target); + builder.connect_affine_point(&randot_doubled, &randot_times_two); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) + } +} diff --git a/src/gadgets/curve_fixed_base.rs b/src/gadgets/curve_fixed_base.rs new file mode 100644 index 0000000..e7656f5 --- /dev/null +++ b/src/gadgets/curve_fixed_base.rs @@ -0,0 +1,118 @@ +use alloc::vec::Vec; + +use num::BigUint; +use plonky2::field::extension::Extendable; +use plonky2::field::types::Field; +use plonky2::hash::hash_types::RichField; +use plonky2::hash::keccak::KeccakHash; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::config::{GenericHashOut, Hasher}; + +use crate::curve::curve_types::{AffinePoint, Curve, CurveScalar}; +use crate::gadgets::curve::{AffinePointTarget, CircuitBuilderCurve}; +use crate::gadgets::curve_windowed_mul::CircuitBuilderWindowedMul; +use crate::gadgets::nonnative::NonNativeTarget; +use crate::gadgets::split_nonnative::CircuitBuilderSplit; + +/// Compute windowed fixed-base scalar multiplication, using a 4-bit window. +pub fn fixed_base_curve_mul_circuit, const D: usize>( + builder: &mut CircuitBuilder, + base: AffinePoint, + scalar: &NonNativeTarget, +) -> AffinePointTarget { + // Holds `(16^i) * base` for `i=0..scalar.value.limbs.len() * 8`. + let scaled_base = (0..scalar.value.limbs.len() * 8).scan(base, |acc, _| { + let tmp = *acc; + for _ in 0..4 { + *acc = acc.double(); + } + Some(tmp) + }); + + let limbs = builder.split_nonnative_to_4_bit_limbs(scalar); + + let hash_0 = KeccakHash::<32>::hash_no_pad(&[F::ZERO]); + let hash_0_scalar = C::ScalarField::from_noncanonical_biguint(BigUint::from_bytes_le( + &GenericHashOut::::to_bytes(&hash_0), + )); + let rando = (CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE).to_affine(); + + let zero = builder.zero(); + let mut result = builder.constant_affine_point(rando); + // `s * P = sum s_i * P_i` with `P_i = (16^i) * P` and `s = sum s_i * (16^i)`. + for (limb, point) in limbs.into_iter().zip(scaled_base) { + // `muls_point[t] = t * P_i` for `t=0..16`. + let mut muls_point = (0..16) + .scan(AffinePoint::ZERO, |acc, _| { + let tmp = *acc; + *acc = (point + *acc).to_affine(); + Some(tmp) + }) + // First element if zero, so we skip it since `constant_affine_point` takes non-zero input. + .skip(1) + .map(|p| builder.constant_affine_point(p)) + .collect::>(); + // We add back a point in position 0. `limb == zero` is checked below, so this point can be arbitrary. + muls_point.insert(0, muls_point[0].clone()); + let is_zero = builder.is_equal(limb, zero); + let should_add = builder.not(is_zero); + // `r = s_i * P_i` + let r = builder.random_access_curve_points(limb, muls_point); + result = builder.curve_conditional_add(&result, &r, should_add); + } + + let to_add = builder.constant_affine_point(-rando); + builder.curve_add(&result, &to_add) +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2::field::secp256k1_scalar::Secp256K1Scalar; + use plonky2::field::types::{PrimeField, Sample}; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::secp256k1::Secp256K1; + use crate::gadgets::biguint::WitnessBigUint; + use crate::gadgets::curve::CircuitBuilderCurve; + use crate::gadgets::curve_fixed_base::fixed_base_curve_mul_circuit; + use crate::gadgets::nonnative::CircuitBuilderNonNative; + + #[test] + #[ignore] + fn test_fixed_base() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let mut pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = Secp256K1::GENERATOR_AFFINE; + let n = Secp256K1Scalar::rand(); + + let res = (CurveScalar(n) * g.to_projective()).to_affine(); + let res_expected = builder.constant_affine_point(res); + builder.curve_assert_valid(&res_expected); + + let n_target = builder.add_virtual_nonnative_target::(); + pw.set_biguint_target(&n_target.value, &n.to_canonical_biguint()); + + let res_target = fixed_base_curve_mul_circuit(&mut builder, g, &n_target); + builder.curve_assert_valid(&res_target); + + builder.connect_affine_point(&res_target, &res_expected); + + dbg!(builder.num_gates()); + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) + } +} diff --git a/src/gadgets/curve_msm.rs b/src/gadgets/curve_msm.rs new file mode 100644 index 0000000..7bb4a6c --- /dev/null +++ b/src/gadgets/curve_msm.rs @@ -0,0 +1,138 @@ +use alloc::vec; + +use num::BigUint; +use plonky2::field::extension::Extendable; +use plonky2::field::types::Field; +use plonky2::hash::hash_types::RichField; +use plonky2::hash::keccak::KeccakHash; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::config::{GenericHashOut, Hasher}; + +use crate::curve::curve_types::{Curve, CurveScalar}; +use crate::gadgets::curve::{AffinePointTarget, CircuitBuilderCurve}; +use crate::gadgets::curve_windowed_mul::CircuitBuilderWindowedMul; +use crate::gadgets::nonnative::NonNativeTarget; +use crate::gadgets::split_nonnative::CircuitBuilderSplit; + +/// Computes `n*p + m*q` using windowed MSM, with a 2-bit window. +/// See Algorithm 9.23 in Handbook of Elliptic and Hyperelliptic Curve Cryptography for a +/// description. +/// Note: Doesn't work if `p == q`. +pub fn curve_msm_circuit, const D: usize>( + builder: &mut CircuitBuilder, + p: &AffinePointTarget, + q: &AffinePointTarget, + n: &NonNativeTarget, + m: &NonNativeTarget, +) -> AffinePointTarget { + let limbs_n = builder.split_nonnative_to_2_bit_limbs(n); + let limbs_m = builder.split_nonnative_to_2_bit_limbs(m); + assert_eq!(limbs_n.len(), limbs_m.len()); + let num_limbs = limbs_n.len(); + + let hash_0 = KeccakHash::<32>::hash_no_pad(&[F::ZERO]); + let hash_0_scalar = C::ScalarField::from_noncanonical_biguint(BigUint::from_bytes_le( + &GenericHashOut::::to_bytes(&hash_0), + )); + let rando = (CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE).to_affine(); + let rando_t = builder.constant_affine_point(rando); + let neg_rando = builder.constant_affine_point(-rando); + + // Precomputes `precomputation[i + 4*j] = i*p + j*q` for `i,j=0..4`. + let mut precomputation = vec![p.clone(); 16]; + let mut cur_p = rando_t.clone(); + let mut cur_q = rando_t.clone(); + for i in 0..4 { + precomputation[i] = cur_p.clone(); + precomputation[4 * i] = cur_q.clone(); + cur_p = builder.curve_add(&cur_p, p); + cur_q = builder.curve_add(&cur_q, q); + } + for i in 1..4 { + precomputation[i] = builder.curve_add(&precomputation[i], &neg_rando); + precomputation[4 * i] = builder.curve_add(&precomputation[4 * i], &neg_rando); + } + for i in 1..4 { + for j in 1..4 { + precomputation[i + 4 * j] = + builder.curve_add(&precomputation[i], &precomputation[4 * j]); + } + } + + let four = builder.constant(F::from_canonical_usize(4)); + + let zero = builder.zero(); + let mut result = rando_t; + for (limb_n, limb_m) in limbs_n.into_iter().zip(limbs_m).rev() { + result = builder.curve_repeated_double(&result, 2); + let index = builder.mul_add(four, limb_m, limb_n); + let r = builder.random_access_curve_points(index, precomputation.clone()); + let is_zero = builder.is_equal(index, zero); + let should_add = builder.not(is_zero); + result = builder.curve_conditional_add(&result, &r, should_add); + } + let starting_point_multiplied = (0..2 * num_limbs).fold(rando, |acc, _| acc.double()); + let to_add = builder.constant_affine_point(-starting_point_multiplied); + result = builder.curve_add(&result, &to_add); + + result +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2::field::secp256k1_scalar::Secp256K1Scalar; + use plonky2::field::types::Sample; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::secp256k1::Secp256K1; + use crate::gadgets::curve::CircuitBuilderCurve; + use crate::gadgets::curve_msm::curve_msm_circuit; + use crate::gadgets::nonnative::CircuitBuilderNonNative; + + #[test] + #[ignore] + fn test_curve_msm() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let p = + (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let q = + (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let n = Secp256K1Scalar::rand(); + let m = Secp256K1Scalar::rand(); + + let res = + (CurveScalar(n) * p.to_projective() + CurveScalar(m) * q.to_projective()).to_affine(); + let res_expected = builder.constant_affine_point(res); + builder.curve_assert_valid(&res_expected); + + let p_target = builder.constant_affine_point(p); + let q_target = builder.constant_affine_point(q); + let n_target = builder.constant_nonnative(n); + let m_target = builder.constant_nonnative(m); + + let res_target = + curve_msm_circuit(&mut builder, &p_target, &q_target, &n_target, &m_target); + builder.curve_assert_valid(&res_target); + + builder.connect_affine_point(&res_target, &res_expected); + + dbg!(builder.num_gates()); + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) + } +} diff --git a/src/gadgets/curve_windowed_mul.rs b/src/gadgets/curve_windowed_mul.rs new file mode 100644 index 0000000..39fad17 --- /dev/null +++ b/src/gadgets/curve_windowed_mul.rs @@ -0,0 +1,254 @@ +use alloc::vec; +use alloc::vec::Vec; +use core::marker::PhantomData; + +use num::BigUint; +use plonky2::field::extension::Extendable; +use plonky2::field::types::{Field, Sample}; +use plonky2::hash::hash_types::RichField; +use plonky2::hash::keccak::KeccakHash; +use plonky2::iop::target::{BoolTarget, Target}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::config::{GenericHashOut, Hasher}; +use plonky2_u32::gadgets::arithmetic_u32::{CircuitBuilderU32, U32Target}; + +use crate::curve::curve_types::{Curve, CurveScalar}; +use crate::gadgets::biguint::BigUintTarget; +use crate::gadgets::curve::{AffinePointTarget, CircuitBuilderCurve}; +use crate::gadgets::nonnative::{CircuitBuilderNonNative, NonNativeTarget}; +use crate::gadgets::split_nonnative::CircuitBuilderSplit; + +const WINDOW_SIZE: usize = 4; + +pub trait CircuitBuilderWindowedMul, const D: usize> { + fn precompute_window( + &mut self, + p: &AffinePointTarget, + ) -> Vec>; + + fn random_access_curve_points( + &mut self, + access_index: Target, + v: Vec>, + ) -> AffinePointTarget; + + fn if_affine_point( + &mut self, + b: BoolTarget, + p1: &AffinePointTarget, + p2: &AffinePointTarget, + ) -> AffinePointTarget; + + fn curve_scalar_mul_windowed( + &mut self, + p: &AffinePointTarget, + n: &NonNativeTarget, + ) -> AffinePointTarget; +} + +impl, const D: usize> CircuitBuilderWindowedMul + for CircuitBuilder +{ + fn precompute_window( + &mut self, + p: &AffinePointTarget, + ) -> Vec> { + let g = (CurveScalar(C::ScalarField::rand()) * C::GENERATOR_PROJECTIVE).to_affine(); + let neg = { + let mut neg = g; + neg.y = -neg.y; + self.constant_affine_point(neg) + }; + + let mut multiples = vec![self.constant_affine_point(g)]; + for i in 1..1 << WINDOW_SIZE { + multiples.push(self.curve_add(p, &multiples[i - 1])); + } + for i in 1..1 << WINDOW_SIZE { + multiples[i] = self.curve_add(&neg, &multiples[i]); + } + multiples + } + + fn random_access_curve_points( + &mut self, + access_index: Target, + v: Vec>, + ) -> AffinePointTarget { + let num_limbs = C::BaseField::BITS / 32; + let zero = self.zero_u32(); + let x_limbs: Vec> = (0..num_limbs) + .map(|i| { + v.iter() + .map(|p| p.x.value.limbs.get(i).unwrap_or(&zero).0) + .collect() + }) + .collect(); + let y_limbs: Vec> = (0..num_limbs) + .map(|i| { + v.iter() + .map(|p| p.y.value.limbs.get(i).unwrap_or(&zero).0) + .collect() + }) + .collect(); + + let selected_x_limbs: Vec<_> = x_limbs + .iter() + .map(|limbs| U32Target(self.random_access(access_index, limbs.clone()))) + .collect(); + let selected_y_limbs: Vec<_> = y_limbs + .iter() + .map(|limbs| U32Target(self.random_access(access_index, limbs.clone()))) + .collect(); + + let x = NonNativeTarget { + value: BigUintTarget { + limbs: selected_x_limbs, + }, + _phantom: PhantomData, + }; + let y = NonNativeTarget { + value: BigUintTarget { + limbs: selected_y_limbs, + }, + _phantom: PhantomData, + }; + AffinePointTarget { x, y } + } + + fn if_affine_point( + &mut self, + b: BoolTarget, + p1: &AffinePointTarget, + p2: &AffinePointTarget, + ) -> AffinePointTarget { + let new_x = self.if_nonnative(b, &p1.x, &p2.x); + let new_y = self.if_nonnative(b, &p1.y, &p2.y); + AffinePointTarget { x: new_x, y: new_y } + } + + fn curve_scalar_mul_windowed( + &mut self, + p: &AffinePointTarget, + n: &NonNativeTarget, + ) -> AffinePointTarget { + let hash_0 = KeccakHash::<25>::hash_no_pad(&[F::ZERO]); + let hash_0_scalar = C::ScalarField::from_noncanonical_biguint(BigUint::from_bytes_le( + &GenericHashOut::::to_bytes(&hash_0), + )); + let starting_point = CurveScalar(hash_0_scalar) * C::GENERATOR_PROJECTIVE; + let starting_point_multiplied = { + let mut cur = starting_point; + for _ in 0..C::ScalarField::BITS { + cur = cur.double(); + } + cur + }; + + let mut result = self.constant_affine_point(starting_point.to_affine()); + + let precomputation = self.precompute_window(p); + let zero = self.zero(); + + let windows = self.split_nonnative_to_4_bit_limbs(n); + for i in (0..windows.len()).rev() { + result = self.curve_repeated_double(&result, WINDOW_SIZE); + let window = windows[i]; + + let to_add = self.random_access_curve_points(window, precomputation.clone()); + let is_zero = self.is_equal(window, zero); + let should_add = self.not(is_zero); + result = self.curve_conditional_add(&result, &to_add, should_add); + } + + let to_subtract = self.constant_affine_point(starting_point_multiplied.to_affine()); + let to_add = self.curve_neg(&to_subtract); + result = self.curve_add(&result, &to_add); + + result + } +} + +#[cfg(test)] +mod tests { + use core::ops::Neg; + + use anyhow::Result; + use plonky2::field::secp256k1_scalar::Secp256K1Scalar; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use rand::rngs::OsRng; + use rand::Rng; + + use super::*; + use crate::curve::secp256k1::Secp256K1; + + #[test] + fn test_random_access_curve_points() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let num_points = 16; + let points: Vec<_> = (0..num_points) + .map(|_| { + let g = (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE) + .to_affine(); + builder.constant_affine_point(g) + }) + .collect(); + + let mut rng = OsRng; + let access_index = rng.gen::() % num_points; + + let access_index_target = builder.constant(F::from_canonical_usize(access_index)); + let selected = builder.random_access_curve_points(access_index_target, points.clone()); + let expected = points[access_index].clone(); + builder.connect_affine_point(&selected, &expected); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) + } + + #[test] + #[ignore] + fn test_curve_windowed_mul() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let g = + (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let five = Secp256K1Scalar::from_canonical_usize(5); + let neg_five = five.neg(); + let neg_five_scalar = CurveScalar::(neg_five); + let neg_five_g = (neg_five_scalar * g.to_projective()).to_affine(); + let neg_five_g_expected = builder.constant_affine_point(neg_five_g); + builder.curve_assert_valid(&neg_five_g_expected); + + let g_target = builder.constant_affine_point(g); + let neg_five_target = builder.constant_nonnative(neg_five); + let neg_five_g_actual = builder.curve_scalar_mul_windowed(&g_target, &neg_five_target); + builder.curve_assert_valid(&neg_five_g_actual); + + builder.connect_affine_point(&neg_five_g_expected, &neg_five_g_actual); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) + } +} diff --git a/src/gadgets/ecdsa.rs b/src/gadgets/ecdsa.rs new file mode 100644 index 0000000..657ec49 --- /dev/null +++ b/src/gadgets/ecdsa.rs @@ -0,0 +1,111 @@ +use core::marker::PhantomData; + +use plonky2::field::extension::Extendable; +use plonky2::field::secp256k1_scalar::Secp256K1Scalar; +use plonky2::hash::hash_types::RichField; +use plonky2::plonk::circuit_builder::CircuitBuilder; + +use crate::curve::curve_types::Curve; +use crate::curve::secp256k1::Secp256K1; +use crate::gadgets::curve::{AffinePointTarget, CircuitBuilderCurve}; +use crate::gadgets::curve_fixed_base::fixed_base_curve_mul_circuit; +use crate::gadgets::glv::CircuitBuilderGlv; +use crate::gadgets::nonnative::{CircuitBuilderNonNative, NonNativeTarget}; + +#[derive(Clone, Debug)] +pub struct ECDSASecretKeyTarget(pub NonNativeTarget); + +#[derive(Clone, Debug)] +pub struct ECDSAPublicKeyTarget(pub AffinePointTarget); + +#[derive(Clone, Debug)] +pub struct ECDSASignatureTarget { + pub r: NonNativeTarget, + pub s: NonNativeTarget, +} + +pub fn verify_message_circuit, const D: usize>( + builder: &mut CircuitBuilder, + msg: NonNativeTarget, + sig: ECDSASignatureTarget, + pk: ECDSAPublicKeyTarget, +) { + let ECDSASignatureTarget { r, s } = sig; + + builder.curve_assert_valid(&pk.0); + + let c = builder.inv_nonnative(&s); + let u1 = builder.mul_nonnative(&msg, &c); + let u2 = builder.mul_nonnative(&r, &c); + + let point1 = fixed_base_curve_mul_circuit(builder, Secp256K1::GENERATOR_AFFINE, &u1); + let point2 = builder.glv_mul(&pk.0, &u2); + let point = builder.curve_add(&point1, &point2); + + let x = NonNativeTarget:: { + value: point.x.value, + _phantom: PhantomData, + }; + builder.connect_nonnative(&r, &x); +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2::field::types::Sample; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + use super::*; + use crate::curve::curve_types::CurveScalar; + use crate::curve::ecdsa::{sign_message, ECDSAPublicKey, ECDSASecretKey, ECDSASignature}; + + fn test_ecdsa_circuit_with_config(config: CircuitConfig) -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + type Curve = Secp256K1; + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let msg = Secp256K1Scalar::rand(); + let msg_target = builder.constant_nonnative(msg); + + let sk = ECDSASecretKey::(Secp256K1Scalar::rand()); + let pk = ECDSAPublicKey((CurveScalar(sk.0) * Curve::GENERATOR_PROJECTIVE).to_affine()); + + let pk_target = ECDSAPublicKeyTarget(builder.constant_affine_point(pk.0)); + + let sig = sign_message(msg, sk); + + let ECDSASignature { r, s } = sig; + let r_target = builder.constant_nonnative(r); + let s_target = builder.constant_nonnative(s); + let sig_target = ECDSASignatureTarget { + r: r_target, + s: s_target, + }; + + verify_message_circuit(&mut builder, msg_target, sig_target, pk_target); + + dbg!(builder.num_gates()); + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } + + #[test] + #[ignore] + fn test_ecdsa_circuit_narrow() -> Result<()> { + test_ecdsa_circuit_with_config(CircuitConfig::standard_ecc_config()) + } + + #[test] + #[ignore] + fn test_ecdsa_circuit_wide() -> Result<()> { + test_ecdsa_circuit_with_config(CircuitConfig::wide_ecc_config()) + } +} diff --git a/src/gadgets/glv.rs b/src/gadgets/glv.rs new file mode 100644 index 0000000..8ffa9c8 --- /dev/null +++ b/src/gadgets/glv.rs @@ -0,0 +1,180 @@ +use alloc::vec::Vec; +use core::marker::PhantomData; + +use plonky2::field::extension::Extendable; +use plonky2::field::secp256k1_base::Secp256K1Base; +use plonky2::field::secp256k1_scalar::Secp256K1Scalar; +use plonky2::field::types::{Field, PrimeField}; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::generator::{GeneratedValues, SimpleGenerator}; +use plonky2::iop::target::{BoolTarget, Target}; +use plonky2::iop::witness::{PartitionWitness, WitnessWrite}; +use plonky2::plonk::circuit_builder::CircuitBuilder; + +use crate::curve::glv::{decompose_secp256k1_scalar, GLV_BETA, GLV_S}; +use crate::curve::secp256k1::Secp256K1; +use crate::gadgets::biguint::{GeneratedValuesBigUint, WitnessBigUint}; +use crate::gadgets::curve::{AffinePointTarget, CircuitBuilderCurve}; +use crate::gadgets::curve_msm::curve_msm_circuit; +use crate::gadgets::nonnative::{CircuitBuilderNonNative, NonNativeTarget}; + +pub trait CircuitBuilderGlv, const D: usize> { + fn secp256k1_glv_beta(&mut self) -> NonNativeTarget; + + fn decompose_secp256k1_scalar( + &mut self, + k: &NonNativeTarget, + ) -> ( + NonNativeTarget, + NonNativeTarget, + BoolTarget, + BoolTarget, + ); + + fn glv_mul( + &mut self, + p: &AffinePointTarget, + k: &NonNativeTarget, + ) -> AffinePointTarget; +} + +impl, const D: usize> CircuitBuilderGlv + for CircuitBuilder +{ + fn secp256k1_glv_beta(&mut self) -> NonNativeTarget { + self.constant_nonnative(GLV_BETA) + } + + fn decompose_secp256k1_scalar( + &mut self, + k: &NonNativeTarget, + ) -> ( + NonNativeTarget, + NonNativeTarget, + BoolTarget, + BoolTarget, + ) { + let k1 = self.add_virtual_nonnative_target_sized::(4); + let k2 = self.add_virtual_nonnative_target_sized::(4); + let k1_neg = self.add_virtual_bool_target_unsafe(); + let k2_neg = self.add_virtual_bool_target_unsafe(); + + self.add_simple_generator(GLVDecompositionGenerator:: { + k: k.clone(), + k1: k1.clone(), + k2: k2.clone(), + k1_neg, + k2_neg, + _phantom: PhantomData, + }); + + // Check that `k1_raw + GLV_S * k2_raw == k`. + let k1_raw = self.nonnative_conditional_neg(&k1, k1_neg); + let k2_raw = self.nonnative_conditional_neg(&k2, k2_neg); + let s = self.constant_nonnative(GLV_S); + let mut should_be_k = self.mul_nonnative(&s, &k2_raw); + should_be_k = self.add_nonnative(&should_be_k, &k1_raw); + self.connect_nonnative(&should_be_k, k); + + (k1, k2, k1_neg, k2_neg) + } + + fn glv_mul( + &mut self, + p: &AffinePointTarget, + k: &NonNativeTarget, + ) -> AffinePointTarget { + let (k1, k2, k1_neg, k2_neg) = self.decompose_secp256k1_scalar(k); + + let beta = self.secp256k1_glv_beta(); + let beta_px = self.mul_nonnative(&beta, &p.x); + let sp = AffinePointTarget:: { + x: beta_px, + y: p.y.clone(), + }; + + let p_neg = self.curve_conditional_neg(p, k1_neg); + let sp_neg = self.curve_conditional_neg(&sp, k2_neg); + curve_msm_circuit(self, &p_neg, &sp_neg, &k1, &k2) + } +} + +#[derive(Debug)] +struct GLVDecompositionGenerator, const D: usize> { + k: NonNativeTarget, + k1: NonNativeTarget, + k2: NonNativeTarget, + k1_neg: BoolTarget, + k2_neg: BoolTarget, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for GLVDecompositionGenerator +{ + fn dependencies(&self) -> Vec { + self.k.value.limbs.iter().map(|l| l.0).collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let k = Secp256K1Scalar::from_noncanonical_biguint( + witness.get_biguint_target(self.k.value.clone()), + ); + + let (k1, k2, k1_neg, k2_neg) = decompose_secp256k1_scalar(k); + + out_buffer.set_biguint_target(&self.k1.value, &k1.to_canonical_biguint()); + out_buffer.set_biguint_target(&self.k2.value, &k2.to_canonical_biguint()); + out_buffer.set_bool_target(self.k1_neg, k1_neg); + out_buffer.set_bool_target(self.k2_neg, k2_neg); + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2::field::secp256k1_scalar::Secp256K1Scalar; + use plonky2::field::types::Sample; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + use crate::curve::curve_types::{Curve, CurveScalar}; + use crate::curve::glv::glv_mul; + use crate::curve::secp256k1::Secp256K1; + use crate::gadgets::curve::CircuitBuilderCurve; + use crate::gadgets::glv::CircuitBuilderGlv; + use crate::gadgets::nonnative::CircuitBuilderNonNative; + + #[test] + #[ignore] + fn test_glv_gadget() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let rando = + (CurveScalar(Secp256K1Scalar::rand()) * Secp256K1::GENERATOR_PROJECTIVE).to_affine(); + let randot = builder.constant_affine_point(rando); + + let scalar = Secp256K1Scalar::rand(); + let scalar_target = builder.constant_nonnative(scalar); + + let rando_glv_scalar = glv_mul(rando.to_projective(), scalar); + let expected = builder.constant_affine_point(rando_glv_scalar.to_affine()); + let actual = builder.glv_mul(&randot, &scalar_target); + builder.connect_affine_point(&expected, &actual); + + dbg!(builder.num_gates()); + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + + data.verify(proof) + } +} diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs new file mode 100644 index 0000000..35b1010 --- /dev/null +++ b/src/gadgets/mod.rs @@ -0,0 +1,9 @@ +pub mod biguint; +pub mod curve; +pub mod curve_fixed_base; +pub mod curve_msm; +pub mod curve_windowed_mul; +pub mod ecdsa; +pub mod glv; +pub mod nonnative; +pub mod split_nonnative; diff --git a/src/gadgets/nonnative.rs b/src/gadgets/nonnative.rs new file mode 100644 index 0000000..f1c8f03 --- /dev/null +++ b/src/gadgets/nonnative.rs @@ -0,0 +1,826 @@ +use alloc::vec; +use alloc::vec::Vec; +use core::marker::PhantomData; + +use num::{BigUint, Integer, One, Zero}; +use plonky2::field::extension::Extendable; +use plonky2::field::types::{Field, PrimeField}; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::generator::{GeneratedValues, SimpleGenerator}; +use plonky2::iop::target::{BoolTarget, Target}; +use plonky2::iop::witness::{PartitionWitness, WitnessWrite}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::util::ceil_div_usize; +use plonky2_u32::gadgets::arithmetic_u32::{CircuitBuilderU32, U32Target}; +use plonky2_u32::gadgets::range_check::range_check_u32_circuit; +use plonky2_u32::witness::GeneratedValuesU32; + +use crate::gadgets::biguint::{ + BigUintTarget, CircuitBuilderBiguint, GeneratedValuesBigUint, WitnessBigUint, +}; + +#[derive(Clone, Debug)] +pub struct NonNativeTarget { + pub(crate) value: BigUintTarget, + pub(crate) _phantom: PhantomData, +} + +pub trait CircuitBuilderNonNative, const D: usize> { + fn num_nonnative_limbs() -> usize { + ceil_div_usize(FF::BITS, 32) + } + + fn biguint_to_nonnative(&mut self, x: &BigUintTarget) -> NonNativeTarget; + + fn nonnative_to_canonical_biguint( + &mut self, + x: &NonNativeTarget, + ) -> BigUintTarget; + + fn constant_nonnative(&mut self, x: FF) -> NonNativeTarget; + + fn zero_nonnative(&mut self) -> NonNativeTarget; + + // Assert that two NonNativeTarget's, both assumed to be in reduced form, are equal. + fn connect_nonnative( + &mut self, + lhs: &NonNativeTarget, + rhs: &NonNativeTarget, + ); + + fn add_virtual_nonnative_target(&mut self) -> NonNativeTarget; + + fn add_virtual_nonnative_target_sized( + &mut self, + num_limbs: usize, + ) -> NonNativeTarget; + + fn add_nonnative( + &mut self, + a: &NonNativeTarget, + b: &NonNativeTarget, + ) -> NonNativeTarget; + + fn mul_nonnative_by_bool( + &mut self, + a: &NonNativeTarget, + b: BoolTarget, + ) -> NonNativeTarget; + + fn if_nonnative( + &mut self, + b: BoolTarget, + x: &NonNativeTarget, + y: &NonNativeTarget, + ) -> NonNativeTarget; + + fn add_many_nonnative( + &mut self, + to_add: &[NonNativeTarget], + ) -> NonNativeTarget; + + // Subtract two `NonNativeTarget`s. + fn sub_nonnative( + &mut self, + a: &NonNativeTarget, + b: &NonNativeTarget, + ) -> NonNativeTarget; + + fn mul_nonnative( + &mut self, + a: &NonNativeTarget, + b: &NonNativeTarget, + ) -> NonNativeTarget; + + fn mul_many_nonnative( + &mut self, + to_mul: &[NonNativeTarget], + ) -> NonNativeTarget; + + fn neg_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget; + + fn inv_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget; + + /// Returns `x % |FF|` as a `NonNativeTarget`. + fn reduce(&mut self, x: &BigUintTarget) -> NonNativeTarget; + + fn reduce_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget; + + fn bool_to_nonnative(&mut self, b: &BoolTarget) -> NonNativeTarget; + + // Split a nonnative field element to bits. + fn split_nonnative_to_bits(&mut self, x: &NonNativeTarget) -> Vec; + + fn nonnative_conditional_neg( + &mut self, + x: &NonNativeTarget, + b: BoolTarget, + ) -> NonNativeTarget; +} + +impl, const D: usize> CircuitBuilderNonNative + for CircuitBuilder +{ + fn num_nonnative_limbs() -> usize { + ceil_div_usize(FF::BITS, 32) + } + + fn biguint_to_nonnative(&mut self, x: &BigUintTarget) -> NonNativeTarget { + NonNativeTarget { + value: x.clone(), + _phantom: PhantomData, + } + } + + fn nonnative_to_canonical_biguint( + &mut self, + x: &NonNativeTarget, + ) -> BigUintTarget { + x.value.clone() + } + + fn constant_nonnative(&mut self, x: FF) -> NonNativeTarget { + let x_biguint = self.constant_biguint(&x.to_canonical_biguint()); + self.biguint_to_nonnative(&x_biguint) + } + + fn zero_nonnative(&mut self) -> NonNativeTarget { + self.constant_nonnative(FF::ZERO) + } + + // Assert that two NonNativeTarget's, both assumed to be in reduced form, are equal. + fn connect_nonnative( + &mut self, + lhs: &NonNativeTarget, + rhs: &NonNativeTarget, + ) { + self.connect_biguint(&lhs.value, &rhs.value); + } + + fn add_virtual_nonnative_target(&mut self) -> NonNativeTarget { + let num_limbs = Self::num_nonnative_limbs::(); + let value = self.add_virtual_biguint_target(num_limbs); + + NonNativeTarget { + value, + _phantom: PhantomData, + } + } + + fn add_virtual_nonnative_target_sized( + &mut self, + num_limbs: usize, + ) -> NonNativeTarget { + let value = self.add_virtual_biguint_target(num_limbs); + + NonNativeTarget { + value, + _phantom: PhantomData, + } + } + + fn add_nonnative( + &mut self, + a: &NonNativeTarget, + b: &NonNativeTarget, + ) -> NonNativeTarget { + let sum = self.add_virtual_nonnative_target::(); + let overflow = self.add_virtual_bool_target_unsafe(); + + self.add_simple_generator(NonNativeAdditionGenerator:: { + a: a.clone(), + b: b.clone(), + sum: sum.clone(), + overflow, + _phantom: PhantomData, + }); + + let sum_expected = self.add_biguint(&a.value, &b.value); + + let modulus = self.constant_biguint(&FF::order()); + let mod_times_overflow = self.mul_biguint_by_bool(&modulus, overflow); + let sum_actual = self.add_biguint(&sum.value, &mod_times_overflow); + self.connect_biguint(&sum_expected, &sum_actual); + + // Range-check result. + // TODO: can potentially leave unreduced until necessary (e.g. when connecting values). + let cmp = self.cmp_biguint(&sum.value, &modulus); + let one = self.one(); + self.connect(cmp.target, one); + + sum + } + + fn mul_nonnative_by_bool( + &mut self, + a: &NonNativeTarget, + b: BoolTarget, + ) -> NonNativeTarget { + NonNativeTarget { + value: self.mul_biguint_by_bool(&a.value, b), + _phantom: PhantomData, + } + } + + fn if_nonnative( + &mut self, + b: BoolTarget, + x: &NonNativeTarget, + y: &NonNativeTarget, + ) -> NonNativeTarget { + let not_b = self.not(b); + let maybe_x = self.mul_nonnative_by_bool(x, b); + let maybe_y = self.mul_nonnative_by_bool(y, not_b); + self.add_nonnative(&maybe_x, &maybe_y) + } + + fn add_many_nonnative( + &mut self, + to_add: &[NonNativeTarget], + ) -> NonNativeTarget { + if to_add.len() == 1 { + return to_add[0].clone(); + } + + let sum = self.add_virtual_nonnative_target::(); + let overflow = self.add_virtual_u32_target(); + let summands = to_add.to_vec(); + + self.add_simple_generator(NonNativeMultipleAddsGenerator:: { + summands: summands.clone(), + sum: sum.clone(), + overflow, + _phantom: PhantomData, + }); + + range_check_u32_circuit(self, sum.value.limbs.clone()); + range_check_u32_circuit(self, vec![overflow]); + + let sum_expected = summands + .iter() + .fold(self.zero_biguint(), |a, b| self.add_biguint(&a, &b.value)); + + let modulus = self.constant_biguint(&FF::order()); + let overflow_biguint = BigUintTarget { + limbs: vec![overflow], + }; + let mod_times_overflow = self.mul_biguint(&modulus, &overflow_biguint); + let sum_actual = self.add_biguint(&sum.value, &mod_times_overflow); + self.connect_biguint(&sum_expected, &sum_actual); + + // Range-check result. + // TODO: can potentially leave unreduced until necessary (e.g. when connecting values). + let cmp = self.cmp_biguint(&sum.value, &modulus); + let one = self.one(); + self.connect(cmp.target, one); + + sum + } + + // Subtract two `NonNativeTarget`s. + fn sub_nonnative( + &mut self, + a: &NonNativeTarget, + b: &NonNativeTarget, + ) -> NonNativeTarget { + let diff = self.add_virtual_nonnative_target::(); + let overflow = self.add_virtual_bool_target_unsafe(); + + self.add_simple_generator(NonNativeSubtractionGenerator:: { + a: a.clone(), + b: b.clone(), + diff: diff.clone(), + overflow, + _phantom: PhantomData, + }); + + range_check_u32_circuit(self, diff.value.limbs.clone()); + self.assert_bool(overflow); + + let diff_plus_b = self.add_biguint(&diff.value, &b.value); + let modulus = self.constant_biguint(&FF::order()); + let mod_times_overflow = self.mul_biguint_by_bool(&modulus, overflow); + let diff_plus_b_reduced = self.sub_biguint(&diff_plus_b, &mod_times_overflow); + self.connect_biguint(&a.value, &diff_plus_b_reduced); + + diff + } + + fn mul_nonnative( + &mut self, + a: &NonNativeTarget, + b: &NonNativeTarget, + ) -> NonNativeTarget { + let prod = self.add_virtual_nonnative_target::(); + let modulus = self.constant_biguint(&FF::order()); + let overflow = self.add_virtual_biguint_target( + a.value.num_limbs() + b.value.num_limbs() - modulus.num_limbs(), + ); + + self.add_simple_generator(NonNativeMultiplicationGenerator:: { + a: a.clone(), + b: b.clone(), + prod: prod.clone(), + overflow: overflow.clone(), + _phantom: PhantomData, + }); + + range_check_u32_circuit(self, prod.value.limbs.clone()); + range_check_u32_circuit(self, overflow.limbs.clone()); + + let prod_expected = self.mul_biguint(&a.value, &b.value); + + let mod_times_overflow = self.mul_biguint(&modulus, &overflow); + let prod_actual = self.add_biguint(&prod.value, &mod_times_overflow); + self.connect_biguint(&prod_expected, &prod_actual); + + prod + } + + fn mul_many_nonnative( + &mut self, + to_mul: &[NonNativeTarget], + ) -> NonNativeTarget { + if to_mul.len() == 1 { + return to_mul[0].clone(); + } + + let mut accumulator = self.mul_nonnative(&to_mul[0], &to_mul[1]); + for t in to_mul.iter().skip(2) { + accumulator = self.mul_nonnative(&accumulator, t); + } + accumulator + } + + fn neg_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { + let zero_target = self.constant_biguint(&BigUint::zero()); + let zero_ff = self.biguint_to_nonnative(&zero_target); + + self.sub_nonnative(&zero_ff, x) + } + + fn inv_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { + let num_limbs = x.value.num_limbs(); + let inv_biguint = self.add_virtual_biguint_target(num_limbs); + let div = self.add_virtual_biguint_target(num_limbs); + + self.add_simple_generator(NonNativeInverseGenerator:: { + x: x.clone(), + inv: inv_biguint.clone(), + div: div.clone(), + _phantom: PhantomData, + }); + + let product = self.mul_biguint(&x.value, &inv_biguint); + + let modulus = self.constant_biguint(&FF::order()); + let mod_times_div = self.mul_biguint(&modulus, &div); + let one = self.constant_biguint(&BigUint::one()); + let expected_product = self.add_biguint(&mod_times_div, &one); + self.connect_biguint(&product, &expected_product); + + NonNativeTarget:: { + value: inv_biguint, + _phantom: PhantomData, + } + } + + /// Returns `x % |FF|` as a `NonNativeTarget`. + fn reduce(&mut self, x: &BigUintTarget) -> NonNativeTarget { + let modulus = FF::order(); + let order_target = self.constant_biguint(&modulus); + let value = self.rem_biguint(x, &order_target); + + NonNativeTarget { + value, + _phantom: PhantomData, + } + } + + fn reduce_nonnative(&mut self, x: &NonNativeTarget) -> NonNativeTarget { + let x_biguint = self.nonnative_to_canonical_biguint(x); + self.reduce(&x_biguint) + } + + fn bool_to_nonnative(&mut self, b: &BoolTarget) -> NonNativeTarget { + let limbs = vec![U32Target(b.target)]; + let value = BigUintTarget { limbs }; + + NonNativeTarget { + value, + _phantom: PhantomData, + } + } + + // Split a nonnative field element to bits. + fn split_nonnative_to_bits(&mut self, x: &NonNativeTarget) -> Vec { + let num_limbs = x.value.num_limbs(); + let mut result = Vec::with_capacity(num_limbs * 32); + + for i in 0..num_limbs { + let limb = x.value.get_limb(i); + let bit_targets = self.split_le_base::<2>(limb.0, 32); + let mut bits: Vec<_> = bit_targets + .iter() + .map(|&t| BoolTarget::new_unsafe(t)) + .collect(); + + result.append(&mut bits); + } + + result + } + + fn nonnative_conditional_neg( + &mut self, + x: &NonNativeTarget, + b: BoolTarget, + ) -> NonNativeTarget { + let not_b = self.not(b); + let neg = self.neg_nonnative(x); + let x_if_true = self.mul_nonnative_by_bool(&neg, b); + let x_if_false = self.mul_nonnative_by_bool(x, not_b); + + self.add_nonnative(&x_if_true, &x_if_false) + } +} + +#[derive(Debug)] +struct NonNativeAdditionGenerator, const D: usize, FF: PrimeField> { + a: NonNativeTarget, + b: NonNativeTarget, + sum: NonNativeTarget, + overflow: BoolTarget, + _phantom: PhantomData, +} + +impl, const D: usize, FF: PrimeField> SimpleGenerator + for NonNativeAdditionGenerator +{ + fn dependencies(&self) -> Vec { + self.a + .value + .limbs + .iter() + .cloned() + .chain(self.b.value.limbs.clone()) + .map(|l| l.0) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let a = FF::from_noncanonical_biguint(witness.get_biguint_target(self.a.value.clone())); + let b = FF::from_noncanonical_biguint(witness.get_biguint_target(self.b.value.clone())); + let a_biguint = a.to_canonical_biguint(); + let b_biguint = b.to_canonical_biguint(); + let sum_biguint = a_biguint + b_biguint; + let modulus = FF::order(); + let (overflow, sum_reduced) = if sum_biguint > modulus { + (true, sum_biguint - modulus) + } else { + (false, sum_biguint) + }; + + out_buffer.set_biguint_target(&self.sum.value, &sum_reduced); + out_buffer.set_bool_target(self.overflow, overflow); + } +} + +#[derive(Debug)] +struct NonNativeMultipleAddsGenerator, const D: usize, FF: PrimeField> +{ + summands: Vec>, + sum: NonNativeTarget, + overflow: U32Target, + _phantom: PhantomData, +} + +impl, const D: usize, FF: PrimeField> SimpleGenerator + for NonNativeMultipleAddsGenerator +{ + fn dependencies(&self) -> Vec { + self.summands + .iter() + .flat_map(|summand| summand.value.limbs.iter().map(|limb| limb.0)) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let summands: Vec<_> = self + .summands + .iter() + .map(|summand| { + FF::from_noncanonical_biguint(witness.get_biguint_target(summand.value.clone())) + }) + .collect(); + let summand_biguints: Vec<_> = summands + .iter() + .map(|summand| summand.to_canonical_biguint()) + .collect(); + + let sum_biguint = summand_biguints + .iter() + .fold(BigUint::zero(), |a, b| a + b.clone()); + + let modulus = FF::order(); + let (overflow_biguint, sum_reduced) = sum_biguint.div_rem(&modulus); + let overflow = overflow_biguint.to_u64_digits()[0] as u32; + + out_buffer.set_biguint_target(&self.sum.value, &sum_reduced); + out_buffer.set_u32_target(self.overflow, overflow); + } +} + +#[derive(Debug)] +struct NonNativeSubtractionGenerator, const D: usize, FF: Field> { + a: NonNativeTarget, + b: NonNativeTarget, + diff: NonNativeTarget, + overflow: BoolTarget, + _phantom: PhantomData, +} + +impl, const D: usize, FF: PrimeField> SimpleGenerator + for NonNativeSubtractionGenerator +{ + fn dependencies(&self) -> Vec { + self.a + .value + .limbs + .iter() + .cloned() + .chain(self.b.value.limbs.clone()) + .map(|l| l.0) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let a = FF::from_noncanonical_biguint(witness.get_biguint_target(self.a.value.clone())); + let b = FF::from_noncanonical_biguint(witness.get_biguint_target(self.b.value.clone())); + let a_biguint = a.to_canonical_biguint(); + let b_biguint = b.to_canonical_biguint(); + + let modulus = FF::order(); + let (diff_biguint, overflow) = if a_biguint >= b_biguint { + (a_biguint - b_biguint, false) + } else { + (modulus + a_biguint - b_biguint, true) + }; + + out_buffer.set_biguint_target(&self.diff.value, &diff_biguint); + out_buffer.set_bool_target(self.overflow, overflow); + } +} + +#[derive(Debug)] +struct NonNativeMultiplicationGenerator, const D: usize, FF: Field> { + a: NonNativeTarget, + b: NonNativeTarget, + prod: NonNativeTarget, + overflow: BigUintTarget, + _phantom: PhantomData, +} + +impl, const D: usize, FF: PrimeField> SimpleGenerator + for NonNativeMultiplicationGenerator +{ + fn dependencies(&self) -> Vec { + self.a + .value + .limbs + .iter() + .cloned() + .chain(self.b.value.limbs.clone()) + .map(|l| l.0) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let a = FF::from_noncanonical_biguint(witness.get_biguint_target(self.a.value.clone())); + let b = FF::from_noncanonical_biguint(witness.get_biguint_target(self.b.value.clone())); + let a_biguint = a.to_canonical_biguint(); + let b_biguint = b.to_canonical_biguint(); + + let prod_biguint = a_biguint * b_biguint; + + let modulus = FF::order(); + let (overflow_biguint, prod_reduced) = prod_biguint.div_rem(&modulus); + + out_buffer.set_biguint_target(&self.prod.value, &prod_reduced); + out_buffer.set_biguint_target(&self.overflow, &overflow_biguint); + } +} + +#[derive(Debug)] +struct NonNativeInverseGenerator, const D: usize, FF: PrimeField> { + x: NonNativeTarget, + inv: BigUintTarget, + div: BigUintTarget, + _phantom: PhantomData, +} + +impl, const D: usize, FF: PrimeField> SimpleGenerator + for NonNativeInverseGenerator +{ + fn dependencies(&self) -> Vec { + self.x.value.limbs.iter().map(|&l| l.0).collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let x = FF::from_noncanonical_biguint(witness.get_biguint_target(self.x.value.clone())); + let inv = x.inverse(); + + let x_biguint = x.to_canonical_biguint(); + let inv_biguint = inv.to_canonical_biguint(); + let prod = x_biguint * &inv_biguint; + let modulus = FF::order(); + let (div, _rem) = prod.div_rem(&modulus); + + out_buffer.set_biguint_target(&self.div, &div); + out_buffer.set_biguint_target(&self.inv, &inv_biguint); + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2::field::secp256k1_base::Secp256K1Base; + use plonky2::field::types::{Field, PrimeField, Sample}; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_builder::CircuitBuilder; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + use crate::gadgets::nonnative::CircuitBuilderNonNative; + + #[test] + fn test_nonnative_add() -> Result<()> { + type FF = Secp256K1Base; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let x_ff = FF::rand(); + let y_ff = FF::rand(); + let sum_ff = x_ff + y_ff; + + let config = CircuitConfig::standard_ecc_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_nonnative(x_ff); + let y = builder.constant_nonnative(y_ff); + let sum = builder.add_nonnative(&x, &y); + + let sum_expected = builder.constant_nonnative(sum_ff); + builder.connect_nonnative(&sum, &sum_expected); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } + + #[test] + fn test_nonnative_many_adds() -> Result<()> { + type FF = Secp256K1Base; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let a_ff = FF::rand(); + let b_ff = FF::rand(); + let c_ff = FF::rand(); + let d_ff = FF::rand(); + let e_ff = FF::rand(); + let f_ff = FF::rand(); + let g_ff = FF::rand(); + let h_ff = FF::rand(); + let sum_ff = a_ff + b_ff + c_ff + d_ff + e_ff + f_ff + g_ff + h_ff; + + let config = CircuitConfig::standard_ecc_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let a = builder.constant_nonnative(a_ff); + let b = builder.constant_nonnative(b_ff); + let c = builder.constant_nonnative(c_ff); + let d = builder.constant_nonnative(d_ff); + let e = builder.constant_nonnative(e_ff); + let f = builder.constant_nonnative(f_ff); + let g = builder.constant_nonnative(g_ff); + let h = builder.constant_nonnative(h_ff); + let all = [a, b, c, d, e, f, g, h]; + let sum = builder.add_many_nonnative(&all); + + let sum_expected = builder.constant_nonnative(sum_ff); + builder.connect_nonnative(&sum, &sum_expected); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } + + #[test] + fn test_nonnative_sub() -> Result<()> { + type FF = Secp256K1Base; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let x_ff = FF::rand(); + let mut y_ff = FF::rand(); + while y_ff.to_canonical_biguint() > x_ff.to_canonical_biguint() { + y_ff = FF::rand(); + } + let diff_ff = x_ff - y_ff; + + let config = CircuitConfig::standard_ecc_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_nonnative(x_ff); + let y = builder.constant_nonnative(y_ff); + let diff = builder.sub_nonnative(&x, &y); + + let diff_expected = builder.constant_nonnative(diff_ff); + builder.connect_nonnative(&diff, &diff_expected); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } + + #[test] + fn test_nonnative_mul() -> Result<()> { + type FF = Secp256K1Base; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + let x_ff = FF::rand(); + let y_ff = FF::rand(); + let product_ff = x_ff * y_ff; + + let config = CircuitConfig::standard_ecc_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_nonnative(x_ff); + let y = builder.constant_nonnative(y_ff); + let product = builder.mul_nonnative(&x, &y); + + let product_expected = builder.constant_nonnative(product_ff); + builder.connect_nonnative(&product, &product_expected); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } + + #[test] + fn test_nonnative_neg() -> Result<()> { + type FF = Secp256K1Base; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + let x_ff = FF::rand(); + let neg_x_ff = -x_ff; + + let config = CircuitConfig::standard_ecc_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_nonnative(x_ff); + let neg_x = builder.neg_nonnative(&x); + + let neg_x_expected = builder.constant_nonnative(neg_x_ff); + builder.connect_nonnative(&neg_x, &neg_x_expected); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } + + #[test] + fn test_nonnative_inv() -> Result<()> { + type FF = Secp256K1Base; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + let x_ff = FF::rand(); + let inv_x_ff = x_ff.inverse(); + + let config = CircuitConfig::standard_ecc_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = builder.constant_nonnative(x_ff); + let inv_x = builder.inv_nonnative(&x); + + let inv_x_expected = builder.constant_nonnative(inv_x_ff); + builder.connect_nonnative(&inv_x, &inv_x_expected); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } +} diff --git a/src/gadgets/split_nonnative.rs b/src/gadgets/split_nonnative.rs new file mode 100644 index 0000000..977912e --- /dev/null +++ b/src/gadgets/split_nonnative.rs @@ -0,0 +1,131 @@ +use alloc::vec::Vec; +use core::marker::PhantomData; + +use itertools::Itertools; +use plonky2::field::extension::Extendable; +use plonky2::field::types::Field; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::target::Target; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2_u32::gadgets::arithmetic_u32::{CircuitBuilderU32, U32Target}; + +use crate::gadgets::biguint::BigUintTarget; +use crate::gadgets::nonnative::NonNativeTarget; + +pub trait CircuitBuilderSplit, const D: usize> { + fn split_u32_to_4_bit_limbs(&mut self, val: U32Target) -> Vec; + + fn split_nonnative_to_4_bit_limbs( + &mut self, + val: &NonNativeTarget, + ) -> Vec; + + fn split_nonnative_to_2_bit_limbs( + &mut self, + val: &NonNativeTarget, + ) -> Vec; + + // Note: assumes its inputs are 4-bit limbs, and does not range-check. + fn recombine_nonnative_4_bit_limbs( + &mut self, + limbs: Vec, + ) -> NonNativeTarget; +} + +impl, const D: usize> CircuitBuilderSplit + for CircuitBuilder +{ + fn split_u32_to_4_bit_limbs(&mut self, val: U32Target) -> Vec { + let two_bit_limbs = self.split_le_base::<4>(val.0, 16); + let four = self.constant(F::from_canonical_usize(4)); + let combined_limbs = two_bit_limbs + .iter() + .tuples() + .map(|(&a, &b)| self.mul_add(b, four, a)) + .collect(); + + combined_limbs + } + + fn split_nonnative_to_4_bit_limbs( + &mut self, + val: &NonNativeTarget, + ) -> Vec { + val.value + .limbs + .iter() + .flat_map(|&l| self.split_u32_to_4_bit_limbs(l)) + .collect() + } + + fn split_nonnative_to_2_bit_limbs( + &mut self, + val: &NonNativeTarget, + ) -> Vec { + val.value + .limbs + .iter() + .flat_map(|&l| self.split_le_base::<4>(l.0, 16)) + .collect() + } + + // Note: assumes its inputs are 4-bit limbs, and does not range-check. + fn recombine_nonnative_4_bit_limbs( + &mut self, + limbs: Vec, + ) -> NonNativeTarget { + let base = self.constant_u32(1 << 4); + let u32_limbs = limbs + .chunks(8) + .map(|chunk| { + let mut combined_chunk = self.zero_u32(); + for i in (0..8).rev() { + let (low, _high) = self.mul_add_u32(combined_chunk, base, U32Target(chunk[i])); + combined_chunk = low; + } + combined_chunk + }) + .collect(); + + NonNativeTarget { + value: BigUintTarget { limbs: u32_limbs }, + _phantom: PhantomData, + } + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2::field::secp256k1_scalar::Secp256K1Scalar; + use plonky2::field::types::Sample; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + use super::*; + use crate::gadgets::nonnative::{CircuitBuilderNonNative, NonNativeTarget}; + + #[test] + fn test_split_nonnative() -> Result<()> { + type FF = Secp256K1Scalar; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_ecc_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let x = FF::rand(); + let x_target = builder.constant_nonnative(x); + let split = builder.split_nonnative_to_4_bit_limbs(&x_target); + let combined: NonNativeTarget = + builder.recombine_nonnative_4_bit_limbs(split); + builder.connect_nonnative(&x_target, &combined); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..bf84913 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,7 @@ +#![allow(clippy::needless_range_loop)] +#![cfg_attr(not(test), no_std)] + +extern crate alloc; + +pub mod curve; +pub mod gadgets;