commit a34a23fec3cf0c290dc82c2fefd666e9f4fa0210 Author: Nicholas Ward Date: Fri Mar 3 16:05:41 2023 -0800 first commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ff35523 --- /dev/null +++ b/.gitignore @@ -0,0 +1,11 @@ +# Cargo build +/target +Cargo.lock + +# Profile-guided optimization +/tmp +pgo-data.profdata + +# MacOS nuisances +.DS_Store + diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..bde1b53 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "plonky2_u32" +description = "u32 gadget for Plonky2" +version = "0.1.0" +license = "MIT OR Apache-2.0" +repository = "https://github.com/mir-protocol/plonky2" +edition = "2021" + +[dependencies] +anyhow = { version = "1.0.40", default-features = false } +itertools = { version = "0.10.0", default-features = false } +num = { version = "0.4", default-features = false } +plonky2 = { version = "0.1.2", default-features = false } + +[dev-dependencies] +plonky2 = { version = "0.1.2", default-features = false, features = ["gate_testing"] } +rand = { version = "0.8.4", default-features = false, features = ["getrandom"] } diff --git a/LICENSE-APACHE b/LICENSE-APACHE new file mode 100644 index 0000000..1e5006d --- /dev/null +++ b/LICENSE-APACHE @@ -0,0 +1,202 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + diff --git a/LICENSE-MIT b/LICENSE-MIT new file mode 100644 index 0000000..86d690b --- /dev/null +++ b/LICENSE-MIT @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2022 The Plonky2 Authors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..bb4e2d8 --- /dev/null +++ b/README.md @@ -0,0 +1,13 @@ +## License + +Licensed under either of + +* Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) +* MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) + +at your option. + + +### Contribution + +Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any additional terms or conditions. diff --git a/src/gadgets/arithmetic_u32.rs b/src/gadgets/arithmetic_u32.rs new file mode 100644 index 0000000..65f5ac0 --- /dev/null +++ b/src/gadgets/arithmetic_u32.rs @@ -0,0 +1,303 @@ +use alloc::vec; +use alloc::vec::Vec; +use core::marker::PhantomData; + +use plonky2::field::extension::Extendable; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::generator::{GeneratedValues, SimpleGenerator}; +use plonky2::iop::target::Target; +use plonky2::iop::witness::{PartitionWitness, Witness}; +use plonky2::plonk::circuit_builder::CircuitBuilder; + +use crate::gates::add_many_u32::U32AddManyGate; +use crate::gates::arithmetic_u32::U32ArithmeticGate; +use crate::gates::subtraction_u32::U32SubtractionGate; +use crate::witness::GeneratedValuesU32; + +#[derive(Clone, Copy, Debug)] +pub struct U32Target(pub Target); + +pub trait CircuitBuilderU32, const D: usize> { + fn add_virtual_u32_target(&mut self) -> U32Target; + + fn add_virtual_u32_targets(&mut self, n: usize) -> Vec; + + /// Returns a U32Target for the value `c`, which is assumed to be at most 32 bits. + fn constant_u32(&mut self, c: u32) -> U32Target; + + fn zero_u32(&mut self) -> U32Target; + + fn one_u32(&mut self) -> U32Target; + + fn connect_u32(&mut self, x: U32Target, y: U32Target); + + fn assert_zero_u32(&mut self, x: U32Target); + + /// Checks for special cases where the value of + /// `x * y + z` + /// can be determined without adding a `U32ArithmeticGate`. + fn arithmetic_u32_special_cases( + &mut self, + x: U32Target, + y: U32Target, + z: U32Target, + ) -> Option<(U32Target, U32Target)>; + + // Returns x * y + z. + fn mul_add_u32(&mut self, x: U32Target, y: U32Target, z: U32Target) -> (U32Target, U32Target); + + fn add_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target); + + fn add_many_u32(&mut self, to_add: &[U32Target]) -> (U32Target, U32Target); + + fn add_u32s_with_carry( + &mut self, + to_add: &[U32Target], + carry: U32Target, + ) -> (U32Target, U32Target); + + fn mul_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target); + + // Returns x - y - borrow, as a pair (result, borrow), where borrow is 0 or 1 depending on whether borrowing from the next digit is required (iff y + borrow > x). + fn sub_u32(&mut self, x: U32Target, y: U32Target, borrow: U32Target) -> (U32Target, U32Target); +} + +impl, const D: usize> CircuitBuilderU32 + for CircuitBuilder +{ + fn add_virtual_u32_target(&mut self) -> U32Target { + U32Target(self.add_virtual_target()) + } + + fn add_virtual_u32_targets(&mut self, n: usize) -> Vec { + self.add_virtual_targets(n) + .into_iter() + .map(U32Target) + .collect() + } + + /// Returns a U32Target for the value `c`, which is assumed to be at most 32 bits. + fn constant_u32(&mut self, c: u32) -> U32Target { + U32Target(self.constant(F::from_canonical_u32(c))) + } + + fn zero_u32(&mut self) -> U32Target { + U32Target(self.zero()) + } + + fn one_u32(&mut self) -> U32Target { + U32Target(self.one()) + } + + fn connect_u32(&mut self, x: U32Target, y: U32Target) { + self.connect(x.0, y.0) + } + + fn assert_zero_u32(&mut self, x: U32Target) { + self.assert_zero(x.0) + } + + /// Checks for special cases where the value of + /// `x * y + z` + /// can be determined without adding a `U32ArithmeticGate`. + fn arithmetic_u32_special_cases( + &mut self, + x: U32Target, + y: U32Target, + z: U32Target, + ) -> Option<(U32Target, U32Target)> { + let x_const = self.target_as_constant(x.0); + let y_const = self.target_as_constant(y.0); + let z_const = self.target_as_constant(z.0); + + // If both terms are constant, return their (constant) sum. + let first_term_const = if let (Some(xx), Some(yy)) = (x_const, y_const) { + Some(xx * yy) + } else { + None + }; + + if let (Some(a), Some(b)) = (first_term_const, z_const) { + let sum = (a + b).to_canonical_u64(); + let (low, high) = (sum as u32, (sum >> 32) as u32); + return Some((self.constant_u32(low), self.constant_u32(high))); + } + + None + } + + // Returns x * y + z. + fn mul_add_u32(&mut self, x: U32Target, y: U32Target, z: U32Target) -> (U32Target, U32Target) { + if let Some(result) = self.arithmetic_u32_special_cases(x, y, z) { + return result; + } + + let gate = U32ArithmeticGate::::new_from_config(&self.config); + let (row, copy) = self.find_slot(gate, &[], &[]); + + self.connect(Target::wire(row, gate.wire_ith_multiplicand_0(copy)), x.0); + self.connect(Target::wire(row, gate.wire_ith_multiplicand_1(copy)), y.0); + self.connect(Target::wire(row, gate.wire_ith_addend(copy)), z.0); + + let output_low = U32Target(Target::wire(row, gate.wire_ith_output_low_half(copy))); + let output_high = U32Target(Target::wire(row, gate.wire_ith_output_high_half(copy))); + + (output_low, output_high) + } + + fn add_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) { + let one = self.one_u32(); + self.mul_add_u32(a, one, b) + } + + fn add_many_u32(&mut self, to_add: &[U32Target]) -> (U32Target, U32Target) { + match to_add.len() { + 0 => (self.zero_u32(), self.zero_u32()), + 1 => (to_add[0], self.zero_u32()), + 2 => self.add_u32(to_add[0], to_add[1]), + _ => { + let num_addends = to_add.len(); + let gate = U32AddManyGate::::new_from_config(&self.config, num_addends); + let (row, copy) = + self.find_slot(gate, &[F::from_canonical_usize(num_addends)], &[]); + + for j in 0..num_addends { + self.connect( + Target::wire(row, gate.wire_ith_op_jth_addend(copy, j)), + to_add[j].0, + ); + } + let zero = self.zero(); + self.connect(Target::wire(row, gate.wire_ith_carry(copy)), zero); + + let output_low = U32Target(Target::wire(row, gate.wire_ith_output_result(copy))); + let output_high = U32Target(Target::wire(row, gate.wire_ith_output_carry(copy))); + + (output_low, output_high) + } + } + } + + fn add_u32s_with_carry( + &mut self, + to_add: &[U32Target], + carry: U32Target, + ) -> (U32Target, U32Target) { + if to_add.len() == 1 { + return self.add_u32(to_add[0], carry); + } + + let num_addends = to_add.len(); + + let gate = U32AddManyGate::::new_from_config(&self.config, num_addends); + let (row, copy) = self.find_slot(gate, &[F::from_canonical_usize(num_addends)], &[]); + + for j in 0..num_addends { + self.connect( + Target::wire(row, gate.wire_ith_op_jth_addend(copy, j)), + to_add[j].0, + ); + } + self.connect(Target::wire(row, gate.wire_ith_carry(copy)), carry.0); + + let output = U32Target(Target::wire(row, gate.wire_ith_output_result(copy))); + let output_carry = U32Target(Target::wire(row, gate.wire_ith_output_carry(copy))); + + (output, output_carry) + } + + fn mul_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) { + let zero = self.zero_u32(); + self.mul_add_u32(a, b, zero) + } + + // Returns x - y - borrow, as a pair (result, borrow), where borrow is 0 or 1 depending on whether borrowing from the next digit is required (iff y + borrow > x). + fn sub_u32(&mut self, x: U32Target, y: U32Target, borrow: U32Target) -> (U32Target, U32Target) { + let gate = U32SubtractionGate::::new_from_config(&self.config); + let (row, copy) = self.find_slot(gate, &[], &[]); + + self.connect(Target::wire(row, gate.wire_ith_input_x(copy)), x.0); + self.connect(Target::wire(row, gate.wire_ith_input_y(copy)), y.0); + self.connect( + Target::wire(row, gate.wire_ith_input_borrow(copy)), + borrow.0, + ); + + let output_result = U32Target(Target::wire(row, gate.wire_ith_output_result(copy))); + let output_borrow = U32Target(Target::wire(row, gate.wire_ith_output_borrow(copy))); + + (output_result, output_borrow) + } +} + +#[derive(Debug)] +struct SplitToU32Generator, const D: usize> { + x: Target, + low: U32Target, + high: U32Target, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for SplitToU32Generator +{ + fn dependencies(&self) -> Vec { + vec![self.x] + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let x = witness.get_target(self.x); + let x_u64 = x.to_canonical_u64(); + let low = x_u64 as u32; + let high = (x_u64 >> 32) as u32; + + out_buffer.set_u32_target(self.low, low); + out_buffer.set_u32_target(self.high, high); + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use rand::rngs::OsRng; + use rand::Rng; + + use super::*; + + #[test] + pub fn test_add_many_u32s() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + const NUM_ADDENDS: usize = 15; + + let config = CircuitConfig::standard_recursion_config(); + + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let mut rng = OsRng; + let mut to_add = Vec::new(); + let mut sum = 0u64; + for _ in 0..NUM_ADDENDS { + let x: u32 = rng.gen(); + sum += x as u64; + to_add.push(builder.constant_u32(x)); + } + let carry = builder.zero_u32(); + let (result_low, result_high) = builder.add_u32s_with_carry(&to_add, carry); + let expected_low = builder.constant_u32((sum % (1 << 32)) as u32); + let expected_high = builder.constant_u32((sum >> 32) as u32); + + builder.connect_u32(result_low, expected_low); + builder.connect_u32(result_high, expected_high); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } +} diff --git a/src/gadgets/mod.rs b/src/gadgets/mod.rs new file mode 100644 index 0000000..622242e --- /dev/null +++ b/src/gadgets/mod.rs @@ -0,0 +1,3 @@ +pub mod arithmetic_u32; +pub mod multiple_comparison; +pub mod range_check; diff --git a/src/gadgets/multiple_comparison.rs b/src/gadgets/multiple_comparison.rs new file mode 100644 index 0000000..8d82c29 --- /dev/null +++ b/src/gadgets/multiple_comparison.rs @@ -0,0 +1,152 @@ +use alloc::vec; +use alloc::vec::Vec; + +use plonky2::field::extension::Extendable; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::target::{BoolTarget, Target}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::util::ceil_div_usize; + +use crate::gadgets::arithmetic_u32::U32Target; +use crate::gates::comparison::ComparisonGate; + +/// Returns true if a is less than or equal to b, considered as base-`2^num_bits` limbs of a large value. +/// This range-checks its inputs. +pub fn list_le_circuit, const D: usize>( + builder: &mut CircuitBuilder, + a: Vec, + b: Vec, + num_bits: usize, +) -> BoolTarget { + assert_eq!( + a.len(), + b.len(), + "Comparison must be between same number of inputs and outputs" + ); + let n = a.len(); + + let chunk_bits = 2; + let num_chunks = ceil_div_usize(num_bits, chunk_bits); + + let one = builder.one(); + let mut result = one; + for i in 0..n { + let a_le_b_gate = ComparisonGate::new(num_bits, num_chunks); + let a_le_b_row = builder.add_gate(a_le_b_gate.clone(), vec![]); + builder.connect( + Target::wire(a_le_b_row, a_le_b_gate.wire_first_input()), + a[i], + ); + builder.connect( + Target::wire(a_le_b_row, a_le_b_gate.wire_second_input()), + b[i], + ); + let a_le_b_result = Target::wire(a_le_b_row, a_le_b_gate.wire_result_bool()); + + let b_le_a_gate = ComparisonGate::new(num_bits, num_chunks); + let b_le_a_row = builder.add_gate(b_le_a_gate.clone(), vec![]); + builder.connect( + Target::wire(b_le_a_row, b_le_a_gate.wire_first_input()), + b[i], + ); + builder.connect( + Target::wire(b_le_a_row, b_le_a_gate.wire_second_input()), + a[i], + ); + let b_le_a_result = Target::wire(b_le_a_row, b_le_a_gate.wire_result_bool()); + + let these_limbs_equal = builder.mul(a_le_b_result, b_le_a_result); + let these_limbs_less_than = builder.sub(one, b_le_a_result); + result = builder.mul_add(these_limbs_equal, result, these_limbs_less_than); + } + + // `result` being boolean is an invariant, maintained because its new value is always + // `x * result + y`, where `x` and `y` are booleans that are not simultaneously true. + BoolTarget::new_unsafe(result) +} + +/// Helper function for comparing, specifically, lists of `U32Target`s. +pub fn list_le_u32_circuit, const D: usize>( + builder: &mut CircuitBuilder, + a: Vec, + b: Vec, +) -> BoolTarget { + let a_targets: Vec = a.iter().map(|&t| t.0).collect(); + let b_targets: Vec = b.iter().map(|&t| t.0).collect(); + + list_le_circuit(builder, a_targets, b_targets, 32) +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use num::BigUint; + use plonky2::field::types::Field; + use plonky2::iop::witness::PartialWitness; + use plonky2::plonk::circuit_data::CircuitConfig; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use rand::rngs::OsRng; + use rand::Rng; + + use super::*; + + fn test_list_le(size: usize, num_bits: usize) -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + let config = CircuitConfig::standard_recursion_config(); + let pw = PartialWitness::new(); + let mut builder = CircuitBuilder::::new(config); + + let mut rng = OsRng; + + let lst1: Vec = (0..size) + .map(|_| rng.gen_range(0..(1 << num_bits))) + .collect(); + let lst2: Vec = (0..size) + .map(|_| rng.gen_range(0..(1 << num_bits))) + .collect(); + + let a_biguint = BigUint::from_slice( + &lst1 + .iter() + .flat_map(|&x| [x as u32, (x >> 32) as u32]) + .collect::>(), + ); + let b_biguint = BigUint::from_slice( + &lst2 + .iter() + .flat_map(|&x| [x as u32, (x >> 32) as u32]) + .collect::>(), + ); + + let a = lst1 + .iter() + .map(|&x| builder.constant(F::from_canonical_u64(x))) + .collect(); + let b = lst2 + .iter() + .map(|&x| builder.constant(F::from_canonical_u64(x))) + .collect(); + + let result = list_le_circuit(&mut builder, a, b, num_bits); + + let expected_result = builder.constant_bool(a_biguint <= b_biguint); + builder.connect(result.target, expected_result.target); + + let data = builder.build::(); + let proof = data.prove(pw).unwrap(); + data.verify(proof) + } + + #[test] + fn test_multiple_comparison() -> Result<()> { + for size in [1, 3, 6] { + for num_bits in [20, 32, 40, 44] { + test_list_le(size, num_bits).unwrap(); + } + } + + Ok(()) + } +} diff --git a/src/gadgets/range_check.rs b/src/gadgets/range_check.rs new file mode 100644 index 0000000..9e8cf2a --- /dev/null +++ b/src/gadgets/range_check.rs @@ -0,0 +1,23 @@ +use alloc::vec; +use alloc::vec::Vec; + +use plonky2::field::extension::Extendable; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::target::Target; +use plonky2::plonk::circuit_builder::CircuitBuilder; + +use crate::gadgets::arithmetic_u32::U32Target; +use crate::gates::range_check_u32::U32RangeCheckGate; + +pub fn range_check_u32_circuit, const D: usize>( + builder: &mut CircuitBuilder, + vals: Vec, +) { + let num_input_limbs = vals.len(); + let gate = U32RangeCheckGate::::new(num_input_limbs); + let row = builder.add_gate(gate, vec![]); + + for i in 0..num_input_limbs { + builder.connect(Target::wire(row, gate.wire_ith_input_limb(i)), vals[i].0); + } +} diff --git a/src/gates/add_many_u32.rs b/src/gates/add_many_u32.rs new file mode 100644 index 0000000..566a782 --- /dev/null +++ b/src/gates/add_many_u32.rs @@ -0,0 +1,456 @@ +use alloc::boxed::Box; +use alloc::format; +use alloc::string::String; +use alloc::vec::Vec; +use core::marker::PhantomData; + +use itertools::unfold; +use plonky2::field::extension::Extendable; +use plonky2::field::types::Field; +use plonky2::gates::gate::Gate; +use plonky2::gates::util::StridedConstraintConsumer; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use plonky2::iop::target::Target; +use plonky2::iop::wire::Wire; +use plonky2::iop::witness::{PartitionWitness, Witness, WitnessWrite}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::circuit_data::CircuitConfig; +use plonky2::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; +use plonky2::util::ceil_div_usize; + +const LOG2_MAX_NUM_ADDENDS: usize = 4; +const MAX_NUM_ADDENDS: usize = 16; + +/// A gate to perform addition on `num_addends` different 32-bit values, plus a small carry +#[derive(Copy, Clone, Debug)] +pub struct U32AddManyGate, const D: usize> { + pub num_addends: usize, + pub num_ops: usize, + _phantom: PhantomData, +} + +impl, const D: usize> U32AddManyGate { + pub fn new_from_config(config: &CircuitConfig, num_addends: usize) -> Self { + Self { + num_addends, + num_ops: Self::num_ops(num_addends, config), + _phantom: PhantomData, + } + } + + pub(crate) fn num_ops(num_addends: usize, config: &CircuitConfig) -> usize { + debug_assert!(num_addends <= MAX_NUM_ADDENDS); + let wires_per_op = (num_addends + 3) + Self::num_limbs(); + let routed_wires_per_op = num_addends + 3; + (config.num_wires / wires_per_op).min(config.num_routed_wires / routed_wires_per_op) + } + + pub fn wire_ith_op_jth_addend(&self, i: usize, j: usize) -> usize { + debug_assert!(i < self.num_ops); + debug_assert!(j < self.num_addends); + (self.num_addends + 3) * i + j + } + pub fn wire_ith_carry(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); + (self.num_addends + 3) * i + self.num_addends + } + + pub fn wire_ith_output_result(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); + (self.num_addends + 3) * i + self.num_addends + 1 + } + pub fn wire_ith_output_carry(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); + (self.num_addends + 3) * i + self.num_addends + 2 + } + + pub fn limb_bits() -> usize { + 2 + } + pub fn num_result_limbs() -> usize { + ceil_div_usize(32, Self::limb_bits()) + } + pub fn num_carry_limbs() -> usize { + ceil_div_usize(LOG2_MAX_NUM_ADDENDS, Self::limb_bits()) + } + pub fn num_limbs() -> usize { + Self::num_result_limbs() + Self::num_carry_limbs() + } + + pub fn wire_ith_output_jth_limb(&self, i: usize, j: usize) -> usize { + debug_assert!(i < self.num_ops); + debug_assert!(j < Self::num_limbs()); + (self.num_addends + 3) * self.num_ops + Self::num_limbs() * i + j + } +} + +impl, const D: usize> Gate for U32AddManyGate { + fn id(&self) -> String { + format!("{self:?}") + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + for i in 0..self.num_ops { + let addends: Vec = (0..self.num_addends) + .map(|j| vars.local_wires[self.wire_ith_op_jth_addend(i, j)]) + .collect(); + let carry = vars.local_wires[self.wire_ith_carry(i)]; + + let computed_output = addends.iter().fold(F::Extension::ZERO, |x, &y| x + y) + carry; + + let output_result = vars.local_wires[self.wire_ith_output_result(i)]; + let output_carry = vars.local_wires[self.wire_ith_output_carry(i)]; + + let base = F::Extension::from_canonical_u64(1 << 32u64); + let combined_output = output_carry * base + output_result; + + constraints.push(combined_output - computed_output); + + let mut combined_result_limbs = F::Extension::ZERO; + let mut combined_carry_limbs = F::Extension::ZERO; + let base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits()); + for j in (0..Self::num_limbs()).rev() { + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; + let max_limb = 1 << Self::limb_bits(); + let product = (0..max_limb) + .map(|x| this_limb - F::Extension::from_canonical_usize(x)) + .product(); + constraints.push(product); + + if j < Self::num_result_limbs() { + combined_result_limbs = base * combined_result_limbs + this_limb; + } else { + combined_carry_limbs = base * combined_carry_limbs + this_limb; + } + } + constraints.push(combined_result_limbs - output_result); + constraints.push(combined_carry_limbs - output_carry); + } + + constraints + } + + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { + for i in 0..self.num_ops { + let addends: Vec = (0..self.num_addends) + .map(|j| vars.local_wires[self.wire_ith_op_jth_addend(i, j)]) + .collect(); + let carry = vars.local_wires[self.wire_ith_carry(i)]; + + let computed_output = addends.iter().fold(F::ZERO, |x, &y| x + y) + carry; + + let output_result = vars.local_wires[self.wire_ith_output_result(i)]; + let output_carry = vars.local_wires[self.wire_ith_output_carry(i)]; + + let base = F::from_canonical_u64(1 << 32u64); + let combined_output = output_carry * base + output_result; + + yield_constr.one(combined_output - computed_output); + + let mut combined_result_limbs = F::ZERO; + let mut combined_carry_limbs = F::ZERO; + let base = F::from_canonical_u64(1u64 << Self::limb_bits()); + for j in (0..Self::num_limbs()).rev() { + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; + let max_limb = 1 << Self::limb_bits(); + let product = (0..max_limb) + .map(|x| this_limb - F::from_canonical_usize(x)) + .product(); + yield_constr.one(product); + + if j < Self::num_result_limbs() { + combined_result_limbs = base * combined_result_limbs + this_limb; + } else { + combined_carry_limbs = base * combined_carry_limbs + this_limb; + } + } + yield_constr.one(combined_result_limbs - output_result); + yield_constr.one(combined_carry_limbs - output_carry); + } + } + + fn eval_unfiltered_circuit( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + for i in 0..self.num_ops { + let addends: Vec> = (0..self.num_addends) + .map(|j| vars.local_wires[self.wire_ith_op_jth_addend(i, j)]) + .collect(); + let carry = vars.local_wires[self.wire_ith_carry(i)]; + + let mut computed_output = carry; + for addend in addends { + computed_output = builder.add_extension(computed_output, addend); + } + + let output_result = vars.local_wires[self.wire_ith_output_result(i)]; + let output_carry = vars.local_wires[self.wire_ith_output_carry(i)]; + + let base: F::Extension = F::from_canonical_u64(1 << 32u64).into(); + let base_target = builder.constant_extension(base); + let combined_output = + builder.mul_add_extension(output_carry, base_target, output_result); + + constraints.push(builder.sub_extension(combined_output, computed_output)); + + let mut combined_result_limbs = builder.zero_extension(); + let mut combined_carry_limbs = builder.zero_extension(); + let base = builder + .constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits())); + for j in (0..Self::num_limbs()).rev() { + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; + let max_limb = 1 << Self::limb_bits(); + + let mut product = builder.one_extension(); + for x in 0..max_limb { + let x_target = + builder.constant_extension(F::Extension::from_canonical_usize(x)); + let diff = builder.sub_extension(this_limb, x_target); + product = builder.mul_extension(product, diff); + } + constraints.push(product); + + if j < Self::num_result_limbs() { + combined_result_limbs = + builder.mul_add_extension(base, combined_result_limbs, this_limb); + } else { + combined_carry_limbs = + builder.mul_add_extension(base, combined_carry_limbs, this_limb); + } + } + constraints.push(builder.sub_extension(combined_result_limbs, output_result)); + constraints.push(builder.sub_extension(combined_carry_limbs, output_carry)); + } + + constraints + } + + fn generators(&self, row: usize, _local_constants: &[F]) -> Vec>> { + (0..self.num_ops) + .map(|i| { + let g: Box> = Box::new( + U32AddManyGenerator { + gate: *self, + row, + i, + _phantom: PhantomData, + } + .adapter(), + ); + g + }) + .collect() + } + + fn num_wires(&self) -> usize { + (self.num_addends + 3) * self.num_ops + Self::num_limbs() * self.num_ops + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 1 << Self::limb_bits() + } + + fn num_constraints(&self) -> usize { + self.num_ops * (3 + Self::num_limbs()) + } +} + +#[derive(Clone, Debug)] +struct U32AddManyGenerator, const D: usize> { + gate: U32AddManyGate, + row: usize, + i: usize, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for U32AddManyGenerator +{ + fn dependencies(&self) -> Vec { + let local_target = |column| Target::wire(self.row, column); + + (0..self.gate.num_addends) + .map(|j| local_target(self.gate.wire_ith_op_jth_addend(self.i, j))) + .chain([local_target(self.gate.wire_ith_carry(self.i))]) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let local_wire = |column| Wire { + row: self.row, + column, + }; + + let get_local_wire = |column| witness.get_wire(local_wire(column)); + + let addends: Vec<_> = (0..self.gate.num_addends) + .map(|j| get_local_wire(self.gate.wire_ith_op_jth_addend(self.i, j))) + .collect(); + let carry = get_local_wire(self.gate.wire_ith_carry(self.i)); + + let output = addends.iter().fold(F::ZERO, |x, &y| x + y) + carry; + let output_u64 = output.to_canonical_u64(); + + let output_carry_u64 = output_u64 >> 32; + let output_result_u64 = output_u64 & ((1 << 32) - 1); + + let output_carry = F::from_canonical_u64(output_carry_u64); + let output_result = F::from_canonical_u64(output_result_u64); + + let output_carry_wire = local_wire(self.gate.wire_ith_output_carry(self.i)); + let output_result_wire = local_wire(self.gate.wire_ith_output_result(self.i)); + + out_buffer.set_wire(output_carry_wire, output_carry); + out_buffer.set_wire(output_result_wire, output_result); + + let num_result_limbs = U32AddManyGate::::num_result_limbs(); + let num_carry_limbs = U32AddManyGate::::num_carry_limbs(); + let limb_base = 1 << U32AddManyGate::::limb_bits(); + + let split_to_limbs = |mut val, num| { + unfold((), move |_| { + let ret = val % limb_base; + val /= limb_base; + Some(ret) + }) + .take(num) + .map(F::from_canonical_u64) + }; + + let result_limbs = split_to_limbs(output_result_u64, num_result_limbs); + let carry_limbs = split_to_limbs(output_carry_u64, num_carry_limbs); + + for (j, limb) in result_limbs.chain(carry_limbs).enumerate() { + let wire = local_wire(self.gate.wire_ith_output_jth_limb(self.i, j)); + out_buffer.set_wire(wire, limb); + } + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2::field::extension::quartic::QuarticExtension; + use plonky2::field::goldilocks_field::GoldilocksField; + use plonky2::field::types::Sample; + use plonky2::gates::gate_testing::{test_eval_fns, test_low_degree}; + use plonky2::hash::hash_types::HashOut; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use rand::rngs::OsRng; + use rand::Rng; + + use super::*; + + #[test] + fn low_degree() { + test_low_degree::(U32AddManyGate:: { + num_addends: 4, + num_ops: 3, + _phantom: PhantomData, + }) + } + + #[test] + fn eval_fns() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + test_eval_fns::(U32AddManyGate:: { + num_addends: 4, + num_ops: 3, + _phantom: PhantomData, + }) + } + + #[test] + fn test_gate_constraint() { + type F = GoldilocksField; + type FF = QuarticExtension; + const D: usize = 4; + const NUM_ADDENDS: usize = 10; + const NUM_U32_ADD_MANY_OPS: usize = 3; + + fn get_wires(addends: Vec>, carries: Vec) -> Vec { + let mut v0 = Vec::new(); + let mut v1 = Vec::new(); + + let num_result_limbs = U32AddManyGate::::num_result_limbs(); + let num_carry_limbs = U32AddManyGate::::num_carry_limbs(); + let limb_base = 1 << U32AddManyGate::::limb_bits(); + for op in 0..NUM_U32_ADD_MANY_OPS { + let adds = &addends[op]; + let ca = carries[op]; + + let output = adds.iter().sum::() + ca; + let output_result = output & ((1 << 32) - 1); + let output_carry = output >> 32; + + let split_to_limbs = |mut val, num| { + unfold((), move |_| { + let ret = val % limb_base; + val /= limb_base; + Some(ret) + }) + .take(num) + .map(F::from_canonical_u64) + }; + + let mut result_limbs: Vec<_> = + split_to_limbs(output_result, num_result_limbs).collect(); + let mut carry_limbs: Vec<_> = + split_to_limbs(output_carry, num_carry_limbs).collect(); + + for a in adds { + v0.push(F::from_canonical_u64(*a)); + } + v0.push(F::from_canonical_u64(ca)); + v0.push(F::from_canonical_u64(output_result)); + v0.push(F::from_canonical_u64(output_carry)); + v1.append(&mut result_limbs); + v1.append(&mut carry_limbs); + } + + v0.iter().chain(v1.iter()).map(|&x| x.into()).collect() + } + + let mut rng = OsRng; + let addends: Vec> = (0..NUM_U32_ADD_MANY_OPS) + .map(|_| (0..NUM_ADDENDS).map(|_| rng.gen::() as u64).collect()) + .collect(); + let carries: Vec<_> = (0..NUM_U32_ADD_MANY_OPS) + .map(|_| rng.gen::() as u64) + .collect(); + + let gate = U32AddManyGate:: { + num_addends: NUM_ADDENDS, + num_ops: NUM_U32_ADD_MANY_OPS, + _phantom: PhantomData, + }; + + let vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires(addends, carries), + public_inputs_hash: &HashOut::rand(), + }; + + assert!( + gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), + "Gate constraints are not satisfied." + ); + } +} diff --git a/src/gates/arithmetic_u32.rs b/src/gates/arithmetic_u32.rs new file mode 100644 index 0000000..c65b32a --- /dev/null +++ b/src/gates/arithmetic_u32.rs @@ -0,0 +1,575 @@ +use alloc::boxed::Box; +use alloc::string::String; +use alloc::vec::Vec; +use alloc::{format, vec}; +use core::marker::PhantomData; + +use itertools::unfold; +use plonky2::field::extension::Extendable; +use plonky2::field::packed::PackedField; +use plonky2::field::types::Field; +use plonky2::gates::gate::Gate; +use plonky2::gates::packed_util::PackedEvaluableBase; +use plonky2::gates::util::StridedConstraintConsumer; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use plonky2::iop::target::Target; +use plonky2::iop::wire::Wire; +use plonky2::iop::witness::{PartitionWitness, Witness, WitnessWrite}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::circuit_data::CircuitConfig; +use plonky2::plonk::vars::{ + EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch, + EvaluationVarsBasePacked, +}; + +/// A gate to perform a basic mul-add on 32-bit values (we assume they are range-checked beforehand). +#[derive(Copy, Clone, Debug)] +pub struct U32ArithmeticGate, const D: usize> { + pub num_ops: usize, + _phantom: PhantomData, +} + +impl, const D: usize> U32ArithmeticGate { + pub fn new_from_config(config: &CircuitConfig) -> Self { + Self { + num_ops: Self::num_ops(config), + _phantom: PhantomData, + } + } + + pub(crate) fn num_ops(config: &CircuitConfig) -> usize { + let wires_per_op = Self::routed_wires_per_op() + Self::num_limbs(); + (config.num_wires / wires_per_op).min(config.num_routed_wires / Self::routed_wires_per_op()) + } + + pub fn wire_ith_multiplicand_0(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); + Self::routed_wires_per_op() * i + } + pub fn wire_ith_multiplicand_1(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); + Self::routed_wires_per_op() * i + 1 + } + pub fn wire_ith_addend(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); + Self::routed_wires_per_op() * i + 2 + } + + pub fn wire_ith_output_low_half(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); + Self::routed_wires_per_op() * i + 3 + } + + pub fn wire_ith_output_high_half(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); + Self::routed_wires_per_op() * i + 4 + } + + pub fn wire_ith_inverse(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); + Self::routed_wires_per_op() * i + 5 + } + + pub fn limb_bits() -> usize { + 2 + } + pub fn num_limbs() -> usize { + 64 / Self::limb_bits() + } + pub fn routed_wires_per_op() -> usize { + 6 + } + pub fn wire_ith_output_jth_limb(&self, i: usize, j: usize) -> usize { + debug_assert!(i < self.num_ops); + debug_assert!(j < Self::num_limbs()); + Self::routed_wires_per_op() * self.num_ops + Self::num_limbs() * i + j + } +} + +impl, const D: usize> Gate for U32ArithmeticGate { + fn id(&self) -> String { + format!("{self:?}") + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + for i in 0..self.num_ops { + let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)]; + let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)]; + let addend = vars.local_wires[self.wire_ith_addend(i)]; + + let computed_output = multiplicand_0 * multiplicand_1 + addend; + + let output_low = vars.local_wires[self.wire_ith_output_low_half(i)]; + let output_high = vars.local_wires[self.wire_ith_output_high_half(i)]; + let inverse = vars.local_wires[self.wire_ith_inverse(i)]; + + // Check canonicity of combined_output = output_high * 2^32 + output_low + let combined_output = { + let base = F::Extension::from_canonical_u64(1 << 32u64); + let one = F::Extension::ONE; + let u32_max = F::Extension::from_canonical_u32(u32::MAX); + + // This is zero if and only if the high limb is `u32::MAX`. + // u32::MAX - output_high + let diff = u32_max - output_high; + // If this is zero, the diff is invertible, so the high limb is not `u32::MAX`. + // inverse * diff - 1 + let hi_not_max = inverse * diff - one; + // If this is zero, either the high limb is not `u32::MAX`, or the low limb is zero. + // hi_not_max * limb_0_u32 + let hi_not_max_or_lo_zero = hi_not_max * output_low; + + constraints.push(hi_not_max_or_lo_zero); + + output_high * base + output_low + }; + + constraints.push(combined_output - computed_output); + + let mut combined_low_limbs = F::Extension::ZERO; + let mut combined_high_limbs = F::Extension::ZERO; + let midpoint = Self::num_limbs() / 2; + let base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits()); + for j in (0..Self::num_limbs()).rev() { + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; + let max_limb = 1 << Self::limb_bits(); + let product = (0..max_limb) + .map(|x| this_limb - F::Extension::from_canonical_usize(x)) + .product(); + constraints.push(product); + + if j < midpoint { + combined_low_limbs = base * combined_low_limbs + this_limb; + } else { + combined_high_limbs = base * combined_high_limbs + this_limb; + } + } + constraints.push(combined_low_limbs - output_low); + constraints.push(combined_high_limbs - output_high); + } + + constraints + } + + fn eval_unfiltered_base_one( + &self, + _vars: EvaluationVarsBase, + _yield_constr: StridedConstraintConsumer, + ) { + panic!("use eval_unfiltered_base_packed instead"); + } + + fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch) -> Vec { + self.eval_unfiltered_base_batch_packed(vars_base) + } + + fn eval_unfiltered_circuit( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + for i in 0..self.num_ops { + let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)]; + let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)]; + let addend = vars.local_wires[self.wire_ith_addend(i)]; + + let computed_output = builder.mul_add_extension(multiplicand_0, multiplicand_1, addend); + + let output_low = vars.local_wires[self.wire_ith_output_low_half(i)]; + let output_high = vars.local_wires[self.wire_ith_output_high_half(i)]; + let inverse = vars.local_wires[self.wire_ith_inverse(i)]; + + // Check canonicity of combined_output = output_high * 2^32 + output_low + let combined_output = { + let base: F::Extension = F::from_canonical_u64(1 << 32u64).into(); + let base_target = builder.constant_extension(base); + let one = builder.one_extension(); + let u32_max = + builder.constant_extension(F::Extension::from_canonical_u32(u32::MAX)); + + // This is zero if and only if the high limb is `u32::MAX`. + let diff = builder.sub_extension(u32_max, output_high); + // If this is zero, the diff is invertible, so the high limb is not `u32::MAX`. + let hi_not_max = builder.mul_sub_extension(inverse, diff, one); + // If this is zero, either the high limb is not `u32::MAX`, or the low limb is zero. + let hi_not_max_or_lo_zero = builder.mul_extension(hi_not_max, output_low); + + constraints.push(hi_not_max_or_lo_zero); + + builder.mul_add_extension(output_high, base_target, output_low) + }; + + constraints.push(builder.sub_extension(combined_output, computed_output)); + + let mut combined_low_limbs = builder.zero_extension(); + let mut combined_high_limbs = builder.zero_extension(); + let midpoint = Self::num_limbs() / 2; + let base = builder + .constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits())); + for j in (0..Self::num_limbs()).rev() { + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; + let max_limb = 1 << Self::limb_bits(); + + let mut product = builder.one_extension(); + for x in 0..max_limb { + let x_target = + builder.constant_extension(F::Extension::from_canonical_usize(x)); + let diff = builder.sub_extension(this_limb, x_target); + product = builder.mul_extension(product, diff); + } + constraints.push(product); + + if j < midpoint { + combined_low_limbs = + builder.mul_add_extension(base, combined_low_limbs, this_limb); + } else { + combined_high_limbs = + builder.mul_add_extension(base, combined_high_limbs, this_limb); + } + } + + constraints.push(builder.sub_extension(combined_low_limbs, output_low)); + constraints.push(builder.sub_extension(combined_high_limbs, output_high)); + } + + constraints + } + + fn generators(&self, row: usize, _local_constants: &[F]) -> Vec>> { + (0..self.num_ops) + .map(|i| { + let g: Box> = Box::new( + U32ArithmeticGenerator { + gate: *self, + row, + i, + _phantom: PhantomData, + } + .adapter(), + ); + g + }) + .collect() + } + + fn num_wires(&self) -> usize { + self.num_ops * (Self::routed_wires_per_op() + Self::num_limbs()) + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 1 << Self::limb_bits() + } + + fn num_constraints(&self) -> usize { + self.num_ops * (4 + Self::num_limbs()) + } +} + +impl, const D: usize> PackedEvaluableBase + for U32ArithmeticGate +{ + fn eval_unfiltered_base_packed>( + &self, + vars: EvaluationVarsBasePacked

, + mut yield_constr: StridedConstraintConsumer

, + ) { + for i in 0..self.num_ops { + let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)]; + let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)]; + let addend = vars.local_wires[self.wire_ith_addend(i)]; + + let computed_output = multiplicand_0 * multiplicand_1 + addend; + + let output_low = vars.local_wires[self.wire_ith_output_low_half(i)]; + let output_high = vars.local_wires[self.wire_ith_output_high_half(i)]; + let inverse = vars.local_wires[self.wire_ith_inverse(i)]; + + let combined_output = { + let base = P::from(F::from_canonical_u64(1 << 32u64)); + let one = P::ONES; + let u32_max = P::from(F::from_canonical_u32(u32::MAX)); + + // This is zero if and only if the high limb is `u32::MAX`. + // u32::MAX - output_high + let diff = u32_max - output_high; + // If this is zero, the diff is invertible, so the high limb is not `u32::MAX`. + // inverse * diff - 1 + let hi_not_max = inverse * diff - one; + // If this is zero, either the high limb is not `u32::MAX`, or the low limb is zero. + // hi_not_max * limb_0_u32 + let hi_not_max_or_lo_zero = hi_not_max * output_low; + + yield_constr.one(hi_not_max_or_lo_zero); + + output_high * base + output_low + }; + + yield_constr.one(combined_output - computed_output); + + let mut combined_low_limbs = P::ZEROS; + let mut combined_high_limbs = P::ZEROS; + let midpoint = Self::num_limbs() / 2; + let base = F::from_canonical_u64(1u64 << Self::limb_bits()); + for j in (0..Self::num_limbs()).rev() { + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; + let max_limb = 1 << Self::limb_bits(); + let product = (0..max_limb) + .map(|x| this_limb - F::from_canonical_usize(x)) + .product(); + yield_constr.one(product); + + if j < midpoint { + combined_low_limbs = combined_low_limbs * base + this_limb; + } else { + combined_high_limbs = combined_high_limbs * base + this_limb; + } + } + yield_constr.one(combined_low_limbs - output_low); + yield_constr.one(combined_high_limbs - output_high); + } + } +} + +#[derive(Clone, Debug)] +struct U32ArithmeticGenerator, const D: usize> { + gate: U32ArithmeticGate, + row: usize, + i: usize, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for U32ArithmeticGenerator +{ + fn dependencies(&self) -> Vec { + let local_target = |column| Target::wire(self.row, column); + + vec![ + local_target(self.gate.wire_ith_multiplicand_0(self.i)), + local_target(self.gate.wire_ith_multiplicand_1(self.i)), + local_target(self.gate.wire_ith_addend(self.i)), + ] + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let local_wire = |column| Wire { + row: self.row, + column, + }; + + let get_local_wire = |column| witness.get_wire(local_wire(column)); + + let multiplicand_0 = get_local_wire(self.gate.wire_ith_multiplicand_0(self.i)); + let multiplicand_1 = get_local_wire(self.gate.wire_ith_multiplicand_1(self.i)); + let addend = get_local_wire(self.gate.wire_ith_addend(self.i)); + + let output = multiplicand_0 * multiplicand_1 + addend; + let mut output_u64 = output.to_canonical_u64(); + + let output_high_u64 = output_u64 >> 32; + let output_low_u64 = output_u64 & ((1 << 32) - 1); + + let output_high = F::from_canonical_u64(output_high_u64); + let output_low = F::from_canonical_u64(output_low_u64); + + let output_high_wire = local_wire(self.gate.wire_ith_output_high_half(self.i)); + let output_low_wire = local_wire(self.gate.wire_ith_output_low_half(self.i)); + + out_buffer.set_wire(output_high_wire, output_high); + out_buffer.set_wire(output_low_wire, output_low); + + let diff = u32::MAX as u64 - output_high_u64; + let inverse = if diff == 0 { + F::ZERO + } else { + F::from_canonical_u64(diff).inverse() + }; + let inverse_wire = local_wire(self.gate.wire_ith_inverse(self.i)); + out_buffer.set_wire(inverse_wire, inverse); + + let num_limbs = U32ArithmeticGate::::num_limbs(); + let limb_base = 1 << U32ArithmeticGate::::limb_bits(); + let output_limbs_u64 = unfold((), move |_| { + let ret = output_u64 % limb_base; + output_u64 /= limb_base; + Some(ret) + }) + .take(num_limbs); + let output_limbs_f = output_limbs_u64.map(F::from_canonical_u64); + + for (j, output_limb) in output_limbs_f.enumerate() { + let wire = local_wire(self.gate.wire_ith_output_jth_limb(self.i, j)); + out_buffer.set_wire(wire, output_limb); + } + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2::field::goldilocks_field::GoldilocksField; + use plonky2::field::types::Sample; + use plonky2::gates::gate_testing::{test_eval_fns, test_low_degree}; + use plonky2::hash::hash_types::HashOut; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use rand::rngs::OsRng; + use rand::Rng; + + use super::*; + + #[test] + fn low_degree() { + test_low_degree::(U32ArithmeticGate:: { + num_ops: 3, + _phantom: PhantomData, + }) + } + + #[test] + fn eval_fns() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + test_eval_fns::(U32ArithmeticGate:: { + num_ops: 3, + _phantom: PhantomData, + }) + } + + fn get_wires< + F: RichField + Extendable, + FF: From, + const D: usize, + const NUM_U32_ARITHMETIC_OPS: usize, + >( + multiplicands_0: Vec, + multiplicands_1: Vec, + addends: Vec, + ) -> Vec { + let mut v0 = Vec::new(); + let mut v1 = Vec::new(); + + let limb_bits = U32ArithmeticGate::::limb_bits(); + let num_limbs = U32ArithmeticGate::::num_limbs(); + let limb_base = 1 << limb_bits; + for c in 0..NUM_U32_ARITHMETIC_OPS { + let m0 = multiplicands_0[c]; + let m1 = multiplicands_1[c]; + let a = addends[c]; + + let mut output = m0 * m1 + a; + let output_low = output & ((1 << 32) - 1); + let output_high = output >> 32; + let diff = u32::MAX as u64 - output_high; + let inverse = if diff == 0 { + F::ZERO + } else { + F::from_canonical_u64(diff).inverse() + }; + + let mut output_limbs = Vec::with_capacity(num_limbs); + for _i in 0..num_limbs { + output_limbs.push(output % limb_base); + output /= limb_base; + } + let mut output_limbs_f: Vec<_> = output_limbs + .into_iter() + .map(F::from_canonical_u64) + .collect(); + + v0.push(F::from_canonical_u64(m0)); + v0.push(F::from_canonical_u64(m1)); + v0.push(F::from_noncanonical_u64(a)); + v0.push(F::from_canonical_u64(output_low)); + v0.push(F::from_canonical_u64(output_high)); + v0.push(inverse); + v1.append(&mut output_limbs_f); + } + + v0.iter().chain(v1.iter()).map(|&x| x.into()).collect() + } + + #[test] + fn test_gate_constraint() { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type FF = >::FE; + const NUM_U32_ARITHMETIC_OPS: usize = 3; + + let mut rng = OsRng; + let multiplicands_0: Vec<_> = (0..NUM_U32_ARITHMETIC_OPS) + .map(|_| rng.gen::() as u64) + .collect(); + let multiplicands_1: Vec<_> = (0..NUM_U32_ARITHMETIC_OPS) + .map(|_| rng.gen::() as u64) + .collect(); + let addends: Vec<_> = (0..NUM_U32_ARITHMETIC_OPS) + .map(|_| rng.gen::() as u64) + .collect(); + + let gate = U32ArithmeticGate:: { + num_ops: NUM_U32_ARITHMETIC_OPS, + _phantom: PhantomData, + }; + + let vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires::( + multiplicands_0, + multiplicands_1, + addends, + ), + public_inputs_hash: &HashOut::rand(), + }; + + assert!( + gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), + "Gate constraints are not satisfied." + ); + } + + #[test] + fn test_canonicity() { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type FF = >::FE; + const NUM_U32_ARITHMETIC_OPS: usize = 3; + + let multiplicands_0 = vec![0; NUM_U32_ARITHMETIC_OPS]; + let multiplicands_1 = vec![0; NUM_U32_ARITHMETIC_OPS]; + // A non-canonical addend will produce a non-canonical output using + // get_wires. + let addends = vec![0xFFFFFFFF00000001; NUM_U32_ARITHMETIC_OPS]; + + let gate = U32ArithmeticGate:: { + num_ops: NUM_U32_ARITHMETIC_OPS, + _phantom: PhantomData, + }; + + let vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires::( + multiplicands_0, + multiplicands_1, + addends, + ), + public_inputs_hash: &HashOut::rand(), + }; + + assert!( + !gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), + "Non-canonical output should not pass constraints." + ); + } +} diff --git a/src/gates/comparison.rs b/src/gates/comparison.rs new file mode 100644 index 0000000..d10f3b8 --- /dev/null +++ b/src/gates/comparison.rs @@ -0,0 +1,710 @@ +use alloc::boxed::Box; +use alloc::string::String; +use alloc::vec::Vec; +use alloc::{format, vec}; +use core::marker::PhantomData; + +use plonky2::field::extension::Extendable; +use plonky2::field::packed::PackedField; +use plonky2::field::types::{Field, Field64}; +use plonky2::gates::gate::Gate; +use plonky2::gates::packed_util::PackedEvaluableBase; +use plonky2::gates::util::StridedConstraintConsumer; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use plonky2::iop::target::Target; +use plonky2::iop::wire::Wire; +use plonky2::iop::witness::{PartitionWitness, Witness, WitnessWrite}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_circuit}; +use plonky2::plonk::vars::{ + EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch, + EvaluationVarsBasePacked, +}; +use plonky2::util::{bits_u64, ceil_div_usize}; + +/// A gate for checking that one value is less than or equal to another. +#[derive(Clone, Debug)] +pub struct ComparisonGate, const D: usize> { + pub(crate) num_bits: usize, + pub(crate) num_chunks: usize, + _phantom: PhantomData, +} + +impl, const D: usize> ComparisonGate { + pub fn new(num_bits: usize, num_chunks: usize) -> Self { + debug_assert!(num_bits < bits_u64(F::ORDER)); + Self { + num_bits, + num_chunks, + _phantom: PhantomData, + } + } + + pub fn chunk_bits(&self) -> usize { + ceil_div_usize(self.num_bits, self.num_chunks) + } + + pub fn wire_first_input(&self) -> usize { + 0 + } + + pub fn wire_second_input(&self) -> usize { + 1 + } + + pub fn wire_result_bool(&self) -> usize { + 2 + } + + pub fn wire_most_significant_diff(&self) -> usize { + 3 + } + + pub fn wire_first_chunk_val(&self, chunk: usize) -> usize { + debug_assert!(chunk < self.num_chunks); + 4 + chunk + } + + pub fn wire_second_chunk_val(&self, chunk: usize) -> usize { + debug_assert!(chunk < self.num_chunks); + 4 + self.num_chunks + chunk + } + + pub fn wire_equality_dummy(&self, chunk: usize) -> usize { + debug_assert!(chunk < self.num_chunks); + 4 + 2 * self.num_chunks + chunk + } + + pub fn wire_chunks_equal(&self, chunk: usize) -> usize { + debug_assert!(chunk < self.num_chunks); + 4 + 3 * self.num_chunks + chunk + } + + pub fn wire_intermediate_value(&self, chunk: usize) -> usize { + debug_assert!(chunk < self.num_chunks); + 4 + 4 * self.num_chunks + chunk + } + + /// The `bit_index`th bit of 2^n - 1 + most_significant_diff. + pub fn wire_most_significant_diff_bit(&self, bit_index: usize) -> usize { + 4 + 5 * self.num_chunks + bit_index + } +} + +impl, const D: usize> Gate for ComparisonGate { + fn id(&self) -> String { + format!("{self:?}") + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let first_input = vars.local_wires[self.wire_first_input()]; + let second_input = vars.local_wires[self.wire_second_input()]; + + // Get chunks and assert that they match + let first_chunks: Vec = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_first_chunk_val(i)]) + .collect(); + let second_chunks: Vec = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_second_chunk_val(i)]) + .collect(); + + let first_chunks_combined = reduce_with_powers( + &first_chunks, + F::Extension::from_canonical_usize(1 << self.chunk_bits()), + ); + let second_chunks_combined = reduce_with_powers( + &second_chunks, + F::Extension::from_canonical_usize(1 << self.chunk_bits()), + ); + + constraints.push(first_chunks_combined - first_input); + constraints.push(second_chunks_combined - second_input); + + let chunk_size = 1 << self.chunk_bits(); + + let mut most_significant_diff_so_far = F::Extension::ZERO; + + for i in 0..self.num_chunks { + // Range-check the chunks to be less than `chunk_size`. + let first_product: F::Extension = (0..chunk_size) + .map(|x| first_chunks[i] - F::Extension::from_canonical_usize(x)) + .product(); + let second_product: F::Extension = (0..chunk_size) + .map(|x| second_chunks[i] - F::Extension::from_canonical_usize(x)) + .product(); + constraints.push(first_product); + constraints.push(second_product); + + let difference = second_chunks[i] - first_chunks[i]; + let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)]; + let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; + + // Two constraints to assert that `chunks_equal` is valid. + constraints.push(difference * equality_dummy - (F::Extension::ONE - chunks_equal)); + constraints.push(chunks_equal * difference); + + // Update `most_significant_diff_so_far`. + let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; + constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far); + most_significant_diff_so_far = + intermediate_value + (F::Extension::ONE - chunks_equal) * difference; + } + + let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; + constraints.push(most_significant_diff - most_significant_diff_so_far); + + let most_significant_diff_bits: Vec = (0..self.chunk_bits() + 1) + .map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)]) + .collect(); + + // Range-check the bits. + for &bit in &most_significant_diff_bits { + constraints.push(bit * (F::Extension::ONE - bit)); + } + + let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::Extension::TWO); + let two_n = F::Extension::from_canonical_u64(1 << self.chunk_bits()); + constraints.push((two_n + most_significant_diff) - bits_combined); + + // Iff first <= second, the top (n + 1st) bit of (2^n + most_significant_diff) will be 1. + let result_bool = vars.local_wires[self.wire_result_bool()]; + constraints.push(result_bool - most_significant_diff_bits[self.chunk_bits()]); + + constraints + } + + fn eval_unfiltered_base_one( + &self, + _vars: EvaluationVarsBase, + _yield_constr: StridedConstraintConsumer, + ) { + panic!("use eval_unfiltered_base_packed instead"); + } + + fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch) -> Vec { + self.eval_unfiltered_base_batch_packed(vars_base) + } + + fn eval_unfiltered_circuit( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let first_input = vars.local_wires[self.wire_first_input()]; + let second_input = vars.local_wires[self.wire_second_input()]; + + // Get chunks and assert that they match + let first_chunks: Vec> = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_first_chunk_val(i)]) + .collect(); + let second_chunks: Vec> = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_second_chunk_val(i)]) + .collect(); + + let chunk_base = builder.constant(F::from_canonical_usize(1 << self.chunk_bits())); + let first_chunks_combined = + reduce_with_powers_ext_circuit(builder, &first_chunks, chunk_base); + let second_chunks_combined = + reduce_with_powers_ext_circuit(builder, &second_chunks, chunk_base); + + constraints.push(builder.sub_extension(first_chunks_combined, first_input)); + constraints.push(builder.sub_extension(second_chunks_combined, second_input)); + + let chunk_size = 1 << self.chunk_bits(); + + let mut most_significant_diff_so_far = builder.zero_extension(); + + let one = builder.one_extension(); + // Find the chosen chunk. + for i in 0..self.num_chunks { + // Range-check the chunks to be less than `chunk_size`. + let mut first_product = one; + let mut second_product = one; + for x in 0..chunk_size { + let x_f = builder.constant_extension(F::Extension::from_canonical_usize(x)); + let first_diff = builder.sub_extension(first_chunks[i], x_f); + let second_diff = builder.sub_extension(second_chunks[i], x_f); + first_product = builder.mul_extension(first_product, first_diff); + second_product = builder.mul_extension(second_product, second_diff); + } + constraints.push(first_product); + constraints.push(second_product); + + let difference = builder.sub_extension(second_chunks[i], first_chunks[i]); + let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)]; + let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; + + // Two constraints to assert that `chunks_equal` is valid. + let diff_times_equal = builder.mul_extension(difference, equality_dummy); + let not_equal = builder.sub_extension(one, chunks_equal); + constraints.push(builder.sub_extension(diff_times_equal, not_equal)); + constraints.push(builder.mul_extension(chunks_equal, difference)); + + // Update `most_significant_diff_so_far`. + let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; + let old_diff = builder.mul_extension(chunks_equal, most_significant_diff_so_far); + constraints.push(builder.sub_extension(intermediate_value, old_diff)); + + let not_equal = builder.sub_extension(one, chunks_equal); + let new_diff = builder.mul_extension(not_equal, difference); + most_significant_diff_so_far = builder.add_extension(intermediate_value, new_diff); + } + + let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; + constraints + .push(builder.sub_extension(most_significant_diff, most_significant_diff_so_far)); + + let most_significant_diff_bits: Vec> = (0..self.chunk_bits() + 1) + .map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)]) + .collect(); + + // Range-check the bits. + for &this_bit in &most_significant_diff_bits { + let inverse = builder.sub_extension(one, this_bit); + constraints.push(builder.mul_extension(this_bit, inverse)); + } + + let two = builder.two(); + let bits_combined = + reduce_with_powers_ext_circuit(builder, &most_significant_diff_bits, two); + let two_n = + builder.constant_extension(F::Extension::from_canonical_u64(1 << self.chunk_bits())); + let sum = builder.add_extension(two_n, most_significant_diff); + constraints.push(builder.sub_extension(sum, bits_combined)); + + // Iff first <= second, the top (n + 1st) bit of (2^n + most_significant_diff) will be 1. + let result_bool = vars.local_wires[self.wire_result_bool()]; + constraints.push( + builder.sub_extension(result_bool, most_significant_diff_bits[self.chunk_bits()]), + ); + + constraints + } + + fn generators(&self, row: usize, _local_constants: &[F]) -> Vec>> { + let gen = ComparisonGenerator:: { + row, + gate: self.clone(), + }; + vec![Box::new(gen.adapter())] + } + + fn num_wires(&self) -> usize { + 4 + 5 * self.num_chunks + (self.chunk_bits() + 1) + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 1 << self.chunk_bits() + } + + fn num_constraints(&self) -> usize { + 6 + 5 * self.num_chunks + self.chunk_bits() + } +} + +impl, const D: usize> PackedEvaluableBase + for ComparisonGate +{ + fn eval_unfiltered_base_packed>( + &self, + vars: EvaluationVarsBasePacked

, + mut yield_constr: StridedConstraintConsumer

, + ) { + let first_input = vars.local_wires[self.wire_first_input()]; + let second_input = vars.local_wires[self.wire_second_input()]; + + // Get chunks and assert that they match + let first_chunks: Vec<_> = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_first_chunk_val(i)]) + .collect(); + let second_chunks: Vec<_> = (0..self.num_chunks) + .map(|i| vars.local_wires[self.wire_second_chunk_val(i)]) + .collect(); + + let first_chunks_combined = reduce_with_powers( + &first_chunks, + F::from_canonical_usize(1 << self.chunk_bits()), + ); + let second_chunks_combined = reduce_with_powers( + &second_chunks, + F::from_canonical_usize(1 << self.chunk_bits()), + ); + + yield_constr.one(first_chunks_combined - first_input); + yield_constr.one(second_chunks_combined - second_input); + + let chunk_size = 1 << self.chunk_bits(); + + let mut most_significant_diff_so_far = P::ZEROS; + + for i in 0..self.num_chunks { + // Range-check the chunks to be less than `chunk_size`. + let first_product: P = (0..chunk_size) + .map(|x| first_chunks[i] - F::from_canonical_usize(x)) + .product(); + let second_product: P = (0..chunk_size) + .map(|x| second_chunks[i] - F::from_canonical_usize(x)) + .product(); + yield_constr.one(first_product); + yield_constr.one(second_product); + + let difference = second_chunks[i] - first_chunks[i]; + let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)]; + let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)]; + + // Two constraints to assert that `chunks_equal` is valid. + yield_constr.one(difference * equality_dummy - (P::ONES - chunks_equal)); + yield_constr.one(chunks_equal * difference); + + // Update `most_significant_diff_so_far`. + let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)]; + yield_constr.one(intermediate_value - chunks_equal * most_significant_diff_so_far); + most_significant_diff_so_far = + intermediate_value + (P::ONES - chunks_equal) * difference; + } + + let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()]; + yield_constr.one(most_significant_diff - most_significant_diff_so_far); + + let most_significant_diff_bits: Vec<_> = (0..self.chunk_bits() + 1) + .map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)]) + .collect(); + + // Range-check the bits. + for &bit in &most_significant_diff_bits { + yield_constr.one(bit * (P::ONES - bit)); + } + + let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::TWO); + let two_n = F::from_canonical_u64(1 << self.chunk_bits()); + yield_constr.one((most_significant_diff + two_n) - bits_combined); + + // Iff first <= second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1. + let result_bool = vars.local_wires[self.wire_result_bool()]; + yield_constr.one(result_bool - most_significant_diff_bits[self.chunk_bits()]); + } +} + +#[derive(Debug)] +struct ComparisonGenerator, const D: usize> { + row: usize, + gate: ComparisonGate, +} + +impl, const D: usize> SimpleGenerator + for ComparisonGenerator +{ + fn dependencies(&self) -> Vec { + let local_target = |column| Target::wire(self.row, column); + + vec![ + local_target(self.gate.wire_first_input()), + local_target(self.gate.wire_second_input()), + ] + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let local_wire = |column| Wire { + row: self.row, + column, + }; + + let get_local_wire = |column| witness.get_wire(local_wire(column)); + + let first_input = get_local_wire(self.gate.wire_first_input()); + let second_input = get_local_wire(self.gate.wire_second_input()); + + let first_input_u64 = first_input.to_canonical_u64(); + let second_input_u64 = second_input.to_canonical_u64(); + + let result = F::from_canonical_usize((first_input_u64 <= second_input_u64) as usize); + + let chunk_size = 1 << self.gate.chunk_bits(); + let first_input_chunks: Vec = (0..self.gate.num_chunks) + .scan(first_input_u64, |acc, _| { + let tmp = *acc % chunk_size; + *acc /= chunk_size; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + let second_input_chunks: Vec = (0..self.gate.num_chunks) + .scan(second_input_u64, |acc, _| { + let tmp = *acc % chunk_size; + *acc /= chunk_size; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + + let chunks_equal: Vec = (0..self.gate.num_chunks) + .map(|i| F::from_bool(first_input_chunks[i] == second_input_chunks[i])) + .collect(); + let equality_dummies: Vec = first_input_chunks + .iter() + .zip(second_input_chunks.iter()) + .map(|(&f, &s)| if f == s { F::ONE } else { F::ONE / (s - f) }) + .collect(); + + let mut most_significant_diff_so_far = F::ZERO; + let mut intermediate_values = Vec::new(); + for i in 0..self.gate.num_chunks { + if first_input_chunks[i] != second_input_chunks[i] { + most_significant_diff_so_far = second_input_chunks[i] - first_input_chunks[i]; + intermediate_values.push(F::ZERO); + } else { + intermediate_values.push(most_significant_diff_so_far); + } + } + let most_significant_diff = most_significant_diff_so_far; + + let two_n = F::from_canonical_usize(1 << self.gate.chunk_bits()); + let two_n_plus_msd = (two_n + most_significant_diff).to_canonical_u64(); + + let msd_bits_u64: Vec = (0..self.gate.chunk_bits() + 1) + .scan(two_n_plus_msd, |acc, _| { + let tmp = *acc % 2; + *acc /= 2; + Some(tmp) + }) + .collect(); + let msd_bits: Vec = msd_bits_u64 + .iter() + .map(|x| F::from_canonical_u64(*x)) + .collect(); + + out_buffer.set_wire(local_wire(self.gate.wire_result_bool()), result); + out_buffer.set_wire( + local_wire(self.gate.wire_most_significant_diff()), + most_significant_diff, + ); + for i in 0..self.gate.num_chunks { + out_buffer.set_wire( + local_wire(self.gate.wire_first_chunk_val(i)), + first_input_chunks[i], + ); + out_buffer.set_wire( + local_wire(self.gate.wire_second_chunk_val(i)), + second_input_chunks[i], + ); + out_buffer.set_wire( + local_wire(self.gate.wire_equality_dummy(i)), + equality_dummies[i], + ); + out_buffer.set_wire(local_wire(self.gate.wire_chunks_equal(i)), chunks_equal[i]); + out_buffer.set_wire( + local_wire(self.gate.wire_intermediate_value(i)), + intermediate_values[i], + ); + } + for i in 0..self.gate.chunk_bits() + 1 { + out_buffer.set_wire( + local_wire(self.gate.wire_most_significant_diff_bit(i)), + msd_bits[i], + ); + } + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2::field::goldilocks_field::GoldilocksField; + use plonky2::field::types::{PrimeField64, Sample}; + use plonky2::gates::gate_testing::{test_eval_fns, test_low_degree}; + use plonky2::hash::hash_types::HashOut; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use rand::rngs::OsRng; + use rand::Rng; + + use super::*; + + #[test] + fn wire_indices() { + type CG = ComparisonGate; + let num_bits = 40; + let num_chunks = 5; + + let gate = CG { + num_bits, + num_chunks, + _phantom: PhantomData, + }; + + assert_eq!(gate.wire_first_input(), 0); + assert_eq!(gate.wire_second_input(), 1); + assert_eq!(gate.wire_result_bool(), 2); + assert_eq!(gate.wire_most_significant_diff(), 3); + assert_eq!(gate.wire_first_chunk_val(0), 4); + assert_eq!(gate.wire_first_chunk_val(4), 8); + assert_eq!(gate.wire_second_chunk_val(0), 9); + assert_eq!(gate.wire_second_chunk_val(4), 13); + assert_eq!(gate.wire_equality_dummy(0), 14); + assert_eq!(gate.wire_equality_dummy(4), 18); + assert_eq!(gate.wire_chunks_equal(0), 19); + assert_eq!(gate.wire_chunks_equal(4), 23); + assert_eq!(gate.wire_intermediate_value(0), 24); + assert_eq!(gate.wire_intermediate_value(4), 28); + assert_eq!(gate.wire_most_significant_diff_bit(0), 29); + assert_eq!(gate.wire_most_significant_diff_bit(8), 37); + } + + #[test] + fn low_degree() { + let num_bits = 40; + let num_chunks = 5; + + test_low_degree::(ComparisonGate::<_, 4>::new(num_bits, num_chunks)) + } + + #[test] + fn eval_fns() -> Result<()> { + let num_bits = 40; + let num_chunks = 5; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + test_eval_fns::(ComparisonGate::<_, 2>::new(num_bits, num_chunks)) + } + + #[test] + fn test_gate_constraint() { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type FF = >::FE; + + let num_bits = 40; + let num_chunks = 5; + let chunk_bits = num_bits / num_chunks; + + // Returns the local wires for a comparison gate given the two inputs. + let get_wires = |first_input: F, second_input: F| -> Vec { + let mut v = Vec::new(); + + let first_input_u64 = first_input.to_canonical_u64(); + let second_input_u64 = second_input.to_canonical_u64(); + + let result_bool = F::from_bool(first_input_u64 <= second_input_u64); + + let chunk_size = 1 << chunk_bits; + let mut first_input_chunks: Vec = (0..num_chunks) + .scan(first_input_u64, |acc, _| { + let tmp = *acc % chunk_size; + *acc /= chunk_size; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + let mut second_input_chunks: Vec = (0..num_chunks) + .scan(second_input_u64, |acc, _| { + let tmp = *acc % chunk_size; + *acc /= chunk_size; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + + let mut chunks_equal: Vec = (0..num_chunks) + .map(|i| F::from_bool(first_input_chunks[i] == second_input_chunks[i])) + .collect(); + let mut equality_dummies: Vec = first_input_chunks + .iter() + .zip(second_input_chunks.iter()) + .map(|(&f, &s)| if f == s { F::ONE } else { F::ONE / (s - f) }) + .collect(); + + let mut most_significant_diff_so_far = F::ZERO; + let mut intermediate_values = Vec::new(); + for i in 0..num_chunks { + if first_input_chunks[i] != second_input_chunks[i] { + most_significant_diff_so_far = second_input_chunks[i] - first_input_chunks[i]; + intermediate_values.push(F::ZERO); + } else { + intermediate_values.push(most_significant_diff_so_far); + } + } + let most_significant_diff = most_significant_diff_so_far; + + let two_n_plus_msd = + (1 << chunk_bits) as u64 + most_significant_diff.to_canonical_u64(); + let mut msd_bits: Vec = (0..chunk_bits + 1) + .scan(two_n_plus_msd, |acc, _| { + let tmp = *acc % 2; + *acc /= 2; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + + v.push(first_input); + v.push(second_input); + v.push(result_bool); + v.push(most_significant_diff); + v.append(&mut first_input_chunks); + v.append(&mut second_input_chunks); + v.append(&mut equality_dummies); + v.append(&mut chunks_equal); + v.append(&mut intermediate_values); + v.append(&mut msd_bits); + + v.iter().map(|&x| x.into()).collect() + }; + + let mut rng = OsRng; + let max: u64 = 1 << (num_bits - 1); + let first_input_u64 = rng.gen_range(0..max); + let second_input_u64 = { + let mut val = rng.gen_range(0..max); + while val < first_input_u64 { + val = rng.gen_range(0..max); + } + val + }; + + let first_input = F::from_canonical_u64(first_input_u64); + let second_input = F::from_canonical_u64(second_input_u64); + + let less_than_gate = ComparisonGate:: { + num_bits, + num_chunks, + _phantom: PhantomData, + }; + let less_than_vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires(first_input, second_input), + public_inputs_hash: &HashOut::rand(), + }; + assert!( + less_than_gate + .eval_unfiltered(less_than_vars) + .iter() + .all(|x| x.is_zero()), + "Gate constraints are not satisfied." + ); + + let equal_gate = ComparisonGate:: { + num_bits, + num_chunks, + _phantom: PhantomData, + }; + let equal_vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires(first_input, first_input), + public_inputs_hash: &HashOut::rand(), + }; + assert!( + equal_gate + .eval_unfiltered(equal_vars) + .iter() + .all(|x| x.is_zero()), + "Gate constraints are not satisfied." + ); + } +} diff --git a/src/gates/mod.rs b/src/gates/mod.rs new file mode 100644 index 0000000..1880b16 --- /dev/null +++ b/src/gates/mod.rs @@ -0,0 +1,5 @@ +pub mod add_many_u32; +pub mod arithmetic_u32; +pub mod comparison; +pub mod range_check_u32; +pub mod subtraction_u32; diff --git a/src/gates/range_check_u32.rs b/src/gates/range_check_u32.rs new file mode 100644 index 0000000..55faa6c --- /dev/null +++ b/src/gates/range_check_u32.rs @@ -0,0 +1,307 @@ +use alloc::boxed::Box; +use alloc::string::String; +use alloc::vec::Vec; +use alloc::{format, vec}; +use core::marker::PhantomData; + +use plonky2::field::extension::Extendable; +use plonky2::field::types::Field; +use plonky2::gates::gate::Gate; +use plonky2::gates::util::StridedConstraintConsumer; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use plonky2::iop::target::Target; +use plonky2::iop::witness::{PartitionWitness, Witness, WitnessWrite}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_circuit}; +use plonky2::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase}; +use plonky2::util::ceil_div_usize; + +/// A gate which can decompose a number into base B little-endian limbs. +#[derive(Copy, Clone, Debug)] +pub struct U32RangeCheckGate, const D: usize> { + pub num_input_limbs: usize, + _phantom: PhantomData, +} + +impl, const D: usize> U32RangeCheckGate { + pub fn new(num_input_limbs: usize) -> Self { + Self { + num_input_limbs, + _phantom: PhantomData, + } + } + + pub const AUX_LIMB_BITS: usize = 2; + pub const BASE: usize = 1 << Self::AUX_LIMB_BITS; + + fn aux_limbs_per_input_limb(&self) -> usize { + ceil_div_usize(32, Self::AUX_LIMB_BITS) + } + pub fn wire_ith_input_limb(&self, i: usize) -> usize { + debug_assert!(i < self.num_input_limbs); + i + } + pub fn wire_ith_input_limb_jth_aux_limb(&self, i: usize, j: usize) -> usize { + debug_assert!(i < self.num_input_limbs); + debug_assert!(j < self.aux_limbs_per_input_limb()); + self.num_input_limbs + self.aux_limbs_per_input_limb() * i + j + } +} + +impl, const D: usize> Gate for U32RangeCheckGate { + fn id(&self) -> String { + format!("{self:?}") + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let base = F::Extension::from_canonical_usize(Self::BASE); + for i in 0..self.num_input_limbs { + let input_limb = vars.local_wires[self.wire_ith_input_limb(i)]; + let aux_limbs: Vec<_> = (0..self.aux_limbs_per_input_limb()) + .map(|j| vars.local_wires[self.wire_ith_input_limb_jth_aux_limb(i, j)]) + .collect(); + let computed_sum = reduce_with_powers(&aux_limbs, base); + + constraints.push(computed_sum - input_limb); + for aux_limb in aux_limbs { + constraints.push( + (0..Self::BASE) + .map(|i| aux_limb - F::Extension::from_canonical_usize(i)) + .product(), + ); + } + } + + constraints + } + + fn eval_unfiltered_base_one( + &self, + vars: EvaluationVarsBase, + mut yield_constr: StridedConstraintConsumer, + ) { + let base = F::from_canonical_usize(Self::BASE); + for i in 0..self.num_input_limbs { + let input_limb = vars.local_wires[self.wire_ith_input_limb(i)]; + let aux_limbs: Vec<_> = (0..self.aux_limbs_per_input_limb()) + .map(|j| vars.local_wires[self.wire_ith_input_limb_jth_aux_limb(i, j)]) + .collect(); + let computed_sum = reduce_with_powers(&aux_limbs, base); + + yield_constr.one(computed_sum - input_limb); + for aux_limb in aux_limbs { + yield_constr.one( + (0..Self::BASE) + .map(|i| aux_limb - F::from_canonical_usize(i)) + .product(), + ); + } + } + } + + fn eval_unfiltered_circuit( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(self.num_constraints()); + + let base = builder.constant(F::from_canonical_usize(Self::BASE)); + for i in 0..self.num_input_limbs { + let input_limb = vars.local_wires[self.wire_ith_input_limb(i)]; + let aux_limbs: Vec<_> = (0..self.aux_limbs_per_input_limb()) + .map(|j| vars.local_wires[self.wire_ith_input_limb_jth_aux_limb(i, j)]) + .collect(); + let computed_sum = reduce_with_powers_ext_circuit(builder, &aux_limbs, base); + + constraints.push(builder.sub_extension(computed_sum, input_limb)); + for aux_limb in aux_limbs { + constraints.push({ + let mut acc = builder.one_extension(); + (0..Self::BASE).for_each(|i| { + // We update our accumulator as: + // acc' = acc (x - i) + // = acc x + (-i) acc + // Since -i is constant, we can do this in one arithmetic_extension call. + let neg_i = -F::from_canonical_usize(i); + acc = builder.arithmetic_extension(F::ONE, neg_i, acc, aux_limb, acc) + }); + acc + }); + } + } + + constraints + } + + fn generators(&self, row: usize, _local_constants: &[F]) -> Vec>> { + let gen = U32RangeCheckGenerator { gate: *self, row }; + vec![Box::new(gen.adapter())] + } + + fn num_wires(&self) -> usize { + self.num_input_limbs * (1 + self.aux_limbs_per_input_limb()) + } + + fn num_constants(&self) -> usize { + 0 + } + + // Bounded by the range-check (x-0)*(x-1)*...*(x-BASE+1). + fn degree(&self) -> usize { + Self::BASE + } + + // 1 for checking the each sum of aux limbs, plus a range check for each aux limb. + fn num_constraints(&self) -> usize { + self.num_input_limbs * (1 + self.aux_limbs_per_input_limb()) + } +} + +#[derive(Debug)] +pub struct U32RangeCheckGenerator, const D: usize> { + gate: U32RangeCheckGate, + row: usize, +} + +impl, const D: usize> SimpleGenerator + for U32RangeCheckGenerator +{ + fn dependencies(&self) -> Vec { + let num_input_limbs = self.gate.num_input_limbs; + (0..num_input_limbs) + .map(|i| Target::wire(self.row, self.gate.wire_ith_input_limb(i))) + .collect() + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let num_input_limbs = self.gate.num_input_limbs; + for i in 0..num_input_limbs { + let sum_value = witness + .get_target(Target::wire(self.row, self.gate.wire_ith_input_limb(i))) + .to_canonical_u64() as u32; + + let base = U32RangeCheckGate::::BASE as u32; + let limbs = (0..self.gate.aux_limbs_per_input_limb()) + .map(|j| Target::wire(self.row, self.gate.wire_ith_input_limb_jth_aux_limb(i, j))); + let limbs_value = (0..self.gate.aux_limbs_per_input_limb()) + .scan(sum_value, |acc, _| { + let tmp = *acc % base; + *acc /= base; + Some(F::from_canonical_u32(tmp)) + }) + .collect::>(); + + for (b, b_value) in limbs.zip(limbs_value) { + out_buffer.set_target(b, b_value); + } + } + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use itertools::unfold; + use plonky2::field::extension::quartic::QuarticExtension; + use plonky2::field::goldilocks_field::GoldilocksField; + use plonky2::field::types::{Field, Sample}; + use plonky2::gates::gate_testing::{test_eval_fns, test_low_degree}; + use plonky2::hash::hash_types::HashOut; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use rand::rngs::OsRng; + use rand::Rng; + + use super::*; + + #[test] + fn low_degree() { + test_low_degree::(U32RangeCheckGate::new(8)) + } + + #[test] + fn eval_fns() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + test_eval_fns::(U32RangeCheckGate::new(8)) + } + + fn test_gate_constraint(input_limbs: Vec) { + type F = GoldilocksField; + type FF = QuarticExtension; + const D: usize = 4; + const AUX_LIMB_BITS: usize = 2; + const BASE: usize = 1 << AUX_LIMB_BITS; + const AUX_LIMBS_PER_INPUT_LIMB: usize = ceil_div_usize(32, AUX_LIMB_BITS); + + fn get_wires(input_limbs: Vec) -> Vec { + let num_input_limbs = input_limbs.len(); + let mut v = Vec::new(); + + for i in 0..num_input_limbs { + let input_limb = input_limbs[i]; + + let split_to_limbs = |mut val, num| { + unfold((), move |_| { + let ret = val % (BASE as u64); + val /= BASE as u64; + Some(ret) + }) + .take(num) + .map(F::from_canonical_u64) + }; + + let mut aux_limbs: Vec<_> = + split_to_limbs(input_limb, AUX_LIMBS_PER_INPUT_LIMB).collect(); + + v.append(&mut aux_limbs); + } + + input_limbs + .iter() + .cloned() + .map(F::from_canonical_u64) + .chain(v.iter().cloned()) + .map(|x| x.into()) + .collect() + } + + let gate = U32RangeCheckGate:: { + num_input_limbs: 8, + _phantom: PhantomData, + }; + + let vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires(input_limbs), + public_inputs_hash: &HashOut::rand(), + }; + + assert!( + gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), + "Gate constraints are not satisfied." + ); + } + + #[test] + fn test_gate_constraint_good() { + let mut rng = OsRng; + let input_limbs: Vec<_> = (0..8).map(|_| rng.gen::() as u64).collect(); + + test_gate_constraint(input_limbs); + } + + #[test] + #[should_panic] + fn test_gate_constraint_bad() { + let mut rng = OsRng; + let input_limbs: Vec<_> = (0..8).map(|_| rng.gen()).collect(); + + test_gate_constraint(input_limbs); + } +} diff --git a/src/gates/subtraction_u32.rs b/src/gates/subtraction_u32.rs new file mode 100644 index 0000000..01f55e0 --- /dev/null +++ b/src/gates/subtraction_u32.rs @@ -0,0 +1,445 @@ +use alloc::boxed::Box; +use alloc::string::String; +use alloc::vec::Vec; +use alloc::{format, vec}; +use core::marker::PhantomData; + +use plonky2::field::extension::Extendable; +use plonky2::field::packed::PackedField; +use plonky2::field::types::Field; +use plonky2::gates::gate::Gate; +use plonky2::gates::packed_util::PackedEvaluableBase; +use plonky2::gates::util::StridedConstraintConsumer; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator}; +use plonky2::iop::target::Target; +use plonky2::iop::wire::Wire; +use plonky2::iop::witness::{PartitionWitness, Witness, WitnessWrite}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::circuit_data::CircuitConfig; +use plonky2::plonk::vars::{ + EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch, + EvaluationVarsBasePacked, +}; + +/// A gate to perform a subtraction on 32-bit limbs: given `x`, `y`, and `borrow`, it returns +/// the result `x - y - borrow` and, if this underflows, a new `borrow`. Inputs are not range-checked. +#[derive(Copy, Clone, Debug)] +pub struct U32SubtractionGate, const D: usize> { + pub num_ops: usize, + _phantom: PhantomData, +} + +impl, const D: usize> U32SubtractionGate { + pub fn new_from_config(config: &CircuitConfig) -> Self { + Self { + num_ops: Self::num_ops(config), + _phantom: PhantomData, + } + } + + pub(crate) fn num_ops(config: &CircuitConfig) -> usize { + let wires_per_op = 5 + Self::num_limbs(); + let routed_wires_per_op = 5; + (config.num_wires / wires_per_op).min(config.num_routed_wires / routed_wires_per_op) + } + + pub fn wire_ith_input_x(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); + 5 * i + } + pub fn wire_ith_input_y(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); + 5 * i + 1 + } + pub fn wire_ith_input_borrow(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); + 5 * i + 2 + } + + pub fn wire_ith_output_result(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); + 5 * i + 3 + } + pub fn wire_ith_output_borrow(&self, i: usize) -> usize { + debug_assert!(i < self.num_ops); + 5 * i + 4 + } + + pub fn limb_bits() -> usize { + 2 + } + // We have limbs for the 32 bits of `output_result`. + pub fn num_limbs() -> usize { + 32 / Self::limb_bits() + } + + pub fn wire_ith_output_jth_limb(&self, i: usize, j: usize) -> usize { + debug_assert!(i < self.num_ops); + debug_assert!(j < Self::num_limbs()); + 5 * self.num_ops + Self::num_limbs() * i + j + } +} + +impl, const D: usize> Gate for U32SubtractionGate { + fn id(&self) -> String { + format!("{self:?}") + } + + fn eval_unfiltered(&self, vars: EvaluationVars) -> Vec { + let mut constraints = Vec::with_capacity(self.num_constraints()); + for i in 0..self.num_ops { + let input_x = vars.local_wires[self.wire_ith_input_x(i)]; + let input_y = vars.local_wires[self.wire_ith_input_y(i)]; + let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)]; + + let result_initial = input_x - input_y - input_borrow; + let base = F::Extension::from_canonical_u64(1 << 32u64); + + let output_result = vars.local_wires[self.wire_ith_output_result(i)]; + let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)]; + + constraints.push(output_result - (result_initial + base * output_borrow)); + + // Range-check output_result to be at most 32 bits. + let mut combined_limbs = F::Extension::ZERO; + let limb_base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits()); + for j in (0..Self::num_limbs()).rev() { + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; + let max_limb = 1 << Self::limb_bits(); + let product = (0..max_limb) + .map(|x| this_limb - F::Extension::from_canonical_usize(x)) + .product(); + constraints.push(product); + + combined_limbs = limb_base * combined_limbs + this_limb; + } + constraints.push(combined_limbs - output_result); + + // Range-check output_borrow to be one bit. + constraints.push(output_borrow * (F::Extension::ONE - output_borrow)); + } + + constraints + } + + fn eval_unfiltered_base_one( + &self, + _vars: EvaluationVarsBase, + _yield_constr: StridedConstraintConsumer, + ) { + panic!("use eval_unfiltered_base_packed instead"); + } + + fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch) -> Vec { + self.eval_unfiltered_base_batch_packed(vars_base) + } + + fn eval_unfiltered_circuit( + &self, + builder: &mut CircuitBuilder, + vars: EvaluationTargets, + ) -> Vec> { + let mut constraints = Vec::with_capacity(self.num_constraints()); + for i in 0..self.num_ops { + let input_x = vars.local_wires[self.wire_ith_input_x(i)]; + let input_y = vars.local_wires[self.wire_ith_input_y(i)]; + let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)]; + + let diff = builder.sub_extension(input_x, input_y); + let result_initial = builder.sub_extension(diff, input_borrow); + let base = builder.constant_extension(F::Extension::from_canonical_u64(1 << 32u64)); + + let output_result = vars.local_wires[self.wire_ith_output_result(i)]; + let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)]; + + let computed_output = builder.mul_add_extension(base, output_borrow, result_initial); + constraints.push(builder.sub_extension(output_result, computed_output)); + + // Range-check output_result to be at most 32 bits. + let mut combined_limbs = builder.zero_extension(); + let limb_base = builder + .constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits())); + for j in (0..Self::num_limbs()).rev() { + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; + let max_limb = 1 << Self::limb_bits(); + let mut product = builder.one_extension(); + for x in 0..max_limb { + let x_target = + builder.constant_extension(F::Extension::from_canonical_usize(x)); + let diff = builder.sub_extension(this_limb, x_target); + product = builder.mul_extension(product, diff); + } + constraints.push(product); + + combined_limbs = builder.mul_add_extension(limb_base, combined_limbs, this_limb); + } + constraints.push(builder.sub_extension(combined_limbs, output_result)); + + // Range-check output_borrow to be one bit. + let one = builder.one_extension(); + let not_borrow = builder.sub_extension(one, output_borrow); + constraints.push(builder.mul_extension(output_borrow, not_borrow)); + } + + constraints + } + + fn generators(&self, row: usize, _local_constants: &[F]) -> Vec>> { + (0..self.num_ops) + .map(|i| { + let g: Box> = Box::new( + U32SubtractionGenerator { + gate: *self, + row, + i, + _phantom: PhantomData, + } + .adapter(), + ); + g + }) + .collect() + } + + fn num_wires(&self) -> usize { + self.num_ops * (5 + Self::num_limbs()) + } + + fn num_constants(&self) -> usize { + 0 + } + + fn degree(&self) -> usize { + 1 << Self::limb_bits() + } + + fn num_constraints(&self) -> usize { + self.num_ops * (3 + Self::num_limbs()) + } +} + +impl, const D: usize> PackedEvaluableBase + for U32SubtractionGate +{ + fn eval_unfiltered_base_packed>( + &self, + vars: EvaluationVarsBasePacked

, + mut yield_constr: StridedConstraintConsumer

, + ) { + for i in 0..self.num_ops { + let input_x = vars.local_wires[self.wire_ith_input_x(i)]; + let input_y = vars.local_wires[self.wire_ith_input_y(i)]; + let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)]; + + let result_initial = input_x - input_y - input_borrow; + let base = F::from_canonical_u64(1 << 32u64); + + let output_result = vars.local_wires[self.wire_ith_output_result(i)]; + let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)]; + + yield_constr.one(output_result - (result_initial + output_borrow * base)); + + // Range-check output_result to be at most 32 bits. + let mut combined_limbs = P::ZEROS; + let limb_base = F::from_canonical_u64(1u64 << Self::limb_bits()); + for j in (0..Self::num_limbs()).rev() { + let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)]; + let max_limb = 1 << Self::limb_bits(); + let product = (0..max_limb) + .map(|x| this_limb - F::from_canonical_usize(x)) + .product(); + yield_constr.one(product); + + combined_limbs = combined_limbs * limb_base + this_limb; + } + yield_constr.one(combined_limbs - output_result); + + // Range-check output_borrow to be one bit. + yield_constr.one(output_borrow * (P::ONES - output_borrow)); + } + } +} + +#[derive(Clone, Debug)] +struct U32SubtractionGenerator, const D: usize> { + gate: U32SubtractionGate, + row: usize, + i: usize, + _phantom: PhantomData, +} + +impl, const D: usize> SimpleGenerator + for U32SubtractionGenerator +{ + fn dependencies(&self) -> Vec { + let local_target = |column| Target::wire(self.row, column); + + vec![ + local_target(self.gate.wire_ith_input_x(self.i)), + local_target(self.gate.wire_ith_input_y(self.i)), + local_target(self.gate.wire_ith_input_borrow(self.i)), + ] + } + + fn run_once(&self, witness: &PartitionWitness, out_buffer: &mut GeneratedValues) { + let local_wire = |column| Wire { + row: self.row, + column, + }; + + let get_local_wire = |column| witness.get_wire(local_wire(column)); + + let input_x = get_local_wire(self.gate.wire_ith_input_x(self.i)); + let input_y = get_local_wire(self.gate.wire_ith_input_y(self.i)); + let input_borrow = get_local_wire(self.gate.wire_ith_input_borrow(self.i)); + + let result_initial = input_x - input_y - input_borrow; + let result_initial_u64 = result_initial.to_canonical_u64(); + let output_borrow = if result_initial_u64 > 1 << 32u64 { + F::ONE + } else { + F::ZERO + }; + + let base = F::from_canonical_u64(1 << 32u64); + let output_result = result_initial + base * output_borrow; + + let output_result_wire = local_wire(self.gate.wire_ith_output_result(self.i)); + let output_borrow_wire = local_wire(self.gate.wire_ith_output_borrow(self.i)); + + out_buffer.set_wire(output_result_wire, output_result); + out_buffer.set_wire(output_borrow_wire, output_borrow); + + let output_result_u64 = output_result.to_canonical_u64(); + + let num_limbs = U32SubtractionGate::::num_limbs(); + let limb_base = 1 << U32SubtractionGate::::limb_bits(); + let output_limbs: Vec<_> = (0..num_limbs) + .scan(output_result_u64, |acc, _| { + let tmp = *acc % limb_base; + *acc /= limb_base; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + + for j in 0..num_limbs { + let wire = local_wire(self.gate.wire_ith_output_jth_limb(self.i, j)); + out_buffer.set_wire(wire, output_limbs[j]); + } + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use plonky2::field::extension::quartic::QuarticExtension; + use plonky2::field::goldilocks_field::GoldilocksField; + use plonky2::field::types::{PrimeField64, Sample}; + use plonky2::gates::gate_testing::{test_eval_fns, test_low_degree}; + use plonky2::hash::hash_types::HashOut; + use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + use rand::rngs::OsRng; + use rand::Rng; + + use super::*; + + #[test] + fn low_degree() { + test_low_degree::(U32SubtractionGate:: { + num_ops: 3, + _phantom: PhantomData, + }) + } + + #[test] + fn eval_fns() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + test_eval_fns::(U32SubtractionGate:: { + num_ops: 3, + _phantom: PhantomData, + }) + } + + #[test] + fn test_gate_constraint() { + type F = GoldilocksField; + type FF = QuarticExtension; + const D: usize = 4; + const NUM_U32_SUBTRACTION_OPS: usize = 3; + + fn get_wires(inputs_x: Vec, inputs_y: Vec, borrows: Vec) -> Vec { + let mut v0 = Vec::new(); + let mut v1 = Vec::new(); + + let limb_bits = U32SubtractionGate::::limb_bits(); + let num_limbs = U32SubtractionGate::::num_limbs(); + let limb_base = 1 << limb_bits; + for c in 0..NUM_U32_SUBTRACTION_OPS { + let input_x = F::from_canonical_u64(inputs_x[c]); + let input_y = F::from_canonical_u64(inputs_y[c]); + let input_borrow = F::from_canonical_u64(borrows[c]); + + let result_initial = input_x - input_y - input_borrow; + let result_initial_u64 = result_initial.to_canonical_u64(); + let output_borrow = if result_initial_u64 > 1 << 32u64 { + F::ONE + } else { + F::ZERO + }; + + let base = F::from_canonical_u64(1 << 32u64); + let output_result = result_initial + base * output_borrow; + + let output_result_u64 = output_result.to_canonical_u64(); + + let mut output_limbs: Vec<_> = (0..num_limbs) + .scan(output_result_u64, |acc, _| { + let tmp = *acc % limb_base; + *acc /= limb_base; + Some(F::from_canonical_u64(tmp)) + }) + .collect(); + + v0.push(input_x); + v0.push(input_y); + v0.push(input_borrow); + v0.push(output_result); + v0.push(output_borrow); + v1.append(&mut output_limbs); + } + + v0.iter().chain(v1.iter()).map(|&x| x.into()).collect() + } + + let mut rng = OsRng; + let inputs_x = (0..NUM_U32_SUBTRACTION_OPS) + .map(|_| rng.gen::() as u64) + .collect(); + let inputs_y = (0..NUM_U32_SUBTRACTION_OPS) + .map(|_| rng.gen::() as u64) + .collect(); + let borrows = (0..NUM_U32_SUBTRACTION_OPS) + .map(|_| (rng.gen::() % 2) as u64) + .collect(); + + let gate = U32SubtractionGate:: { + num_ops: NUM_U32_SUBTRACTION_OPS, + _phantom: PhantomData, + }; + + let vars = EvaluationVars { + local_constants: &[], + local_wires: &get_wires(inputs_x, inputs_y, borrows), + public_inputs_hash: &HashOut::rand(), + }; + + assert!( + gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()), + "Gate constraints are not satisfied." + ); + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..2d8d07f --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,8 @@ +#![allow(clippy::needless_range_loop)] +#![no_std] + +extern crate alloc; + +pub mod gadgets; +pub mod gates; +pub mod witness; diff --git a/src/witness.rs b/src/witness.rs new file mode 100644 index 0000000..cf308d2 --- /dev/null +++ b/src/witness.rs @@ -0,0 +1,33 @@ +use plonky2::field::types::{Field, PrimeField64}; +use plonky2::iop::generator::GeneratedValues; +use plonky2::iop::witness::{Witness, WitnessWrite}; + +use crate::gadgets::arithmetic_u32::U32Target; + +pub trait WitnessU32: Witness { + fn set_u32_target(&mut self, target: U32Target, value: u32); + fn get_u32_target(&self, target: U32Target) -> (u32, u32); +} + +impl, F: PrimeField64> WitnessU32 for T { + fn set_u32_target(&mut self, target: U32Target, value: u32) { + self.set_target(target.0, F::from_canonical_u32(value)); + } + + fn get_u32_target(&self, target: U32Target) -> (u32, u32) { + let x_u64 = self.get_target(target.0).to_canonical_u64(); + let low = x_u64 as u32; + let high = (x_u64 >> 32) as u32; + (low, high) + } +} + +pub trait GeneratedValuesU32 { + fn set_u32_target(&mut self, target: U32Target, value: u32); +} + +impl GeneratedValuesU32 for GeneratedValues { + fn set_u32_target(&mut self, target: U32Target, value: u32) { + self.set_target(target.0, F::from_canonical_u32(value)) + } +}