diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4811ae6..b1dc311 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -33,7 +33,7 @@ jobs: name: Test runs-on: ubuntu-latest env: - RUSTFLAGS: -Dwarnings + RUSTFLAGS: -Dwarnings --cfg ci strategy: matrix: rust: diff --git a/CHANGELOG.md b/CHANGELOG.md index 8f1cfca..5da5270 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,9 @@ ### Features +- [\#79](https://github.com/arkworks-rs/r1cs-std/pull/79) Move `NonNativeFieldVar` from `ark-nonnative` to `ark-r1cs-std`. - [\#76](https://github.com/arkworks-rs/r1cs-std/pull/76) Implement `ToBytesGadget` for `Vec`. +- [nonnative/\#45](https://github.com/arkworks-rs/nonnative/pull/45) Add `new_witness_with_le_bits` which returns the bits used during variable allocation. ### Improvements @@ -32,10 +34,11 @@ - [\#60](https://github.com/arkworks-rs/r1cs-std/pull/60) Rename `AllocatedBit` to `AllocatedBool` for consistency with the `Boolean` variable. You can update downstream usage with `grep -rl 'AllocatedBit' . | xargs env LANG=C env LC_CTYPE=C sed -i '' 's/AllocatedBit/AllocatedBool/g'`. - [\#65](https://github.com/arkworks-rs/r1cs-std/pull/65) Rename `Radix2Domain` in `r1cs-std` to `Radix2DomainVar`. +- [nonnative/\#43](https://github.com/arkworks-rs/nonnative/pull/43) Add padding to allocated nonnative element's `to_bytes`. ### Features -- [\#53](https://github.com/arkworks-rs/r1cs-std/pull/53) Add univariate evaluation domain and Lagrange interpolation. +- [\#53](https://github.com/arkworks-rs/r1cs-std/pull/53) Add univariate evaluation domain and Lagrange interpolation. ### Improvements diff --git a/Cargo.toml b/Cargo.toml index ad83144..d760eeb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,12 +28,26 @@ derivative = { version = "2", features = ["use_core"] } tracing = { version = "0.1", default-features = false, features = [ "attributes" ] } num-bigint = {version = "0.4", default-features = false } num-traits = {version = "0.2", default-features = false } +num-integer = { version = "0.1.44", default-features = false } [dev-dependencies] ark-test-curves = { version = "^0.3.0", default-features = false, features = ["bls12_381_scalar_field", "bls12_381_curve", "mnt4_753_scalar_field"] } ark-poly = { version = "^0.3.0", default-features = false } +paste = "1.0" +ark-bls12-377 = { version = "^0.3.0", features = ["curve"], default-features = false } +ark-bls12-381 = { version = "^0.3.0", features = ["curve"], default-features = false } +ark-mnt4-298 = { version = "^0.3.0", features = ["curve"], default-features = false } +ark-mnt4-753 = { version = "^0.3.0", features = ["curve"], default-features = false } +ark-mnt6-298 = { version = "^0.3.0", default-features = false } +ark-mnt6-753 = { version = "^0.3.0", default-features = false } +ark-pallas = { version = "^0.3.0", features = ["curve"], default-features = false } [features] default = ["std"] std = [ "ark-ff/std", "ark-relations/std", "ark-std/std", "num-bigint/std" ] parallel = [ "std", "ark-ff/parallel", "ark-std/parallel"] + +[[bench]] +name = "nonnative-bench" +path = "benches/bench.rs" +harness = false \ No newline at end of file diff --git a/benches/bench.rs b/benches/bench.rs new file mode 100644 index 0000000..8765a01 --- /dev/null +++ b/benches/bench.rs @@ -0,0 +1,235 @@ +use ark_ff::PrimeField; +use ark_r1cs_std::fields::nonnative::NonNativeFieldVar; +use ark_r1cs_std::{alloc::AllocVar, eq::EqGadget, fields::FieldVar}; +use ark_relations::{ + ns, + r1cs::{ConstraintSystem, ConstraintSystemRef, OptimizationGoal}, +}; +use ark_std::rand::RngCore; + +const NUM_REPETITIONS: usize = 1; + +fn get_density(cs: &ConstraintSystemRef) -> usize { + match cs { + ConstraintSystemRef::None => panic!("Constraint system is none."), + ConstraintSystemRef::CS(r) => { + let mut cs_bak = r.borrow().clone(); + + cs_bak.finalize(); + let matrices = cs_bak.to_matrices().unwrap(); + + matrices.a_num_non_zero + matrices.b_num_non_zero + matrices.c_num_non_zero + } + } +} + +fn allocation( + cs: ConstraintSystemRef, + rng: &mut R, +) -> (usize, usize) { + let a_native = TargetField::rand(rng); + + let constraints_before = cs.num_constraints(); + let nonzeros_before = get_density(&cs); + + // There will be a check that ensures it has the reasonable number of bits + let _ = NonNativeFieldVar::::new_witness(ns!(cs, "alloc a"), || { + Ok(a_native) + }) + .unwrap(); + + let constraints_after = cs.num_constraints(); + let nonzeros_after = get_density(&cs); + + return ( + constraints_after - constraints_before, + nonzeros_after - nonzeros_before, + ); +} + +fn addition( + cs: ConstraintSystemRef, + rng: &mut R, +) -> (usize, usize) { + let a_native = TargetField::rand(rng); + let a = NonNativeFieldVar::::new_witness(ns!(cs, "alloc a"), || { + Ok(a_native) + }) + .unwrap(); + + let b_native = TargetField::rand(rng); + let b = NonNativeFieldVar::::new_witness(ns!(cs, "alloc b"), || { + Ok(b_native) + }) + .unwrap(); + + let constraints_before = cs.num_constraints(); + let nonzeros_before = get_density(&cs); + + let _ = &a + &b; + + let constraints_after = cs.num_constraints(); + let nonzeros_after = get_density(&cs); + + return ( + constraints_after - constraints_before, + nonzeros_after - nonzeros_before, + ); +} + +fn equality( + cs: ConstraintSystemRef, + rng: &mut R, +) -> (usize, usize) { + let a_native = TargetField::rand(rng); + let a1 = NonNativeFieldVar::::new_witness(ns!(cs, "alloc a1"), || { + Ok(a_native) + }) + .unwrap(); + let a2 = NonNativeFieldVar::::new_witness(ns!(cs, "alloc a2"), || { + Ok(a_native) + }) + .unwrap(); + + let constraints_before = cs.num_constraints(); + let nonzeros_before = get_density(&cs); + + a1.enforce_equal(&a2).unwrap(); + + let constraints_after = cs.num_constraints(); + let nonzeros_after = get_density(&cs); + + return ( + constraints_after - constraints_before, + nonzeros_after - nonzeros_before, + ); +} + +fn multiplication( + cs: ConstraintSystemRef, + rng: &mut R, +) -> (usize, usize) { + let a_native = TargetField::rand(rng); + let a = NonNativeFieldVar::::new_witness(ns!(cs, "initial a"), || { + Ok(a_native) + }) + .unwrap(); + + let b_native = TargetField::rand(rng); + let b = NonNativeFieldVar::::new_witness(ns!(cs, "initial b"), || { + Ok(b_native) + }) + .unwrap(); + + let constraints_before = cs.num_constraints(); + let nonzeros_before = get_density(&cs); + + let _ = &a * &b; + + let constraints_after = cs.num_constraints(); + let nonzeros_after = get_density(&cs); + + return ( + constraints_after - constraints_before, + nonzeros_after - nonzeros_before, + ); +} + +fn inverse( + cs: ConstraintSystemRef, + rng: &mut R, +) -> (usize, usize) { + let num_native = TargetField::rand(rng); + let num = NonNativeFieldVar::::new_witness(ns!(cs, "alloc"), || { + Ok(num_native) + }) + .unwrap(); + + let constraints_before = cs.num_constraints(); + let nonzeros_before = get_density(&cs); + + let _ = num.inverse().unwrap(); + + let constraints_after = cs.num_constraints(); + let nonzeros_after = get_density(&cs); + + return ( + constraints_after - constraints_before, + nonzeros_after - nonzeros_before, + ); +} + +macro_rules! nonnative_bench_individual { + ($bench_method:ident, $bench_name:ident, $bench_target_field:ty, $bench_base_field:ty) => { + let rng = &mut ark_std::test_rng(); + let mut num_constraints = 0; + let mut num_nonzeros = 0; + for _ in 0..NUM_REPETITIONS { + let cs_sys = ConstraintSystem::<$bench_base_field>::new(); + let cs = ConstraintSystemRef::new(cs_sys); + cs.set_optimization_goal(OptimizationGoal::Constraints); + + let (cur_constraints, cur_nonzeros) = + $bench_method::<$bench_target_field, $bench_base_field, _>(cs.clone(), rng); + + num_constraints += cur_constraints; + num_nonzeros += cur_nonzeros; + + assert!(cs.is_satisfied().unwrap()); + } + let average_constraints = num_constraints / NUM_REPETITIONS; + let average_nonzeros = num_nonzeros / NUM_REPETITIONS; + println!( + "{} takes: {} constraints, {} non-zeros", + stringify!($bench_method), + average_constraints, + average_nonzeros, + ); + }; +} + +macro_rules! nonnative_bench { + ($bench_name:ident, $bench_target_field:ty, $bench_base_field:ty) => { + println!( + "For {} to simulate {}", + stringify!($bench_base_field), + stringify!($bench_target_field), + ); + nonnative_bench_individual!( + allocation, + $bench_name, + $bench_target_field, + $bench_base_field + ); + nonnative_bench_individual!( + addition, + $bench_name, + $bench_target_field, + $bench_base_field + ); + nonnative_bench_individual!( + multiplication, + $bench_name, + $bench_target_field, + $bench_base_field + ); + nonnative_bench_individual!( + equality, + $bench_name, + $bench_target_field, + $bench_base_field + ); + nonnative_bench_individual!(inverse, $bench_name, $bench_target_field, $bench_base_field); + println!("----------------------") + }; +} + +fn main() { + nonnative_bench!(MNT46Small, ark_mnt4_298::Fr, ark_mnt6_298::Fr); + nonnative_bench!(MNT64Small, ark_mnt6_298::Fr, ark_mnt4_298::Fr); + nonnative_bench!(MNT46Big, ark_mnt4_753::Fr, ark_mnt6_753::Fr); + nonnative_bench!(MNT64Big, ark_mnt6_753::Fr, ark_mnt4_753::Fr); + nonnative_bench!(BLS12MNT4Small, ark_bls12_381::Fr, ark_mnt4_298::Fr); + nonnative_bench!(BLS12, ark_bls12_381::Fq, ark_bls12_381::Fr); + nonnative_bench!(MNT6BigMNT4Small, ark_mnt6_753::Fr, ark_mnt4_298::Fr); +} diff --git a/src/fields/mod.rs b/src/fields/mod.rs index 27f6c64..a199d73 100644 --- a/src/fields/mod.rs +++ b/src/fields/mod.rs @@ -20,6 +20,10 @@ pub mod quadratic_extension; /// That is, it implements the R1CS equivalent of `ark_ff::Fp*`. pub mod fp; +/// This module contains a generic implementation of "nonnative" prime field variables. +/// It emulates `Fp` arithmetic using `Fq` operations, where `p != q`. +pub mod nonnative; + /// This module contains a generic implementation of the degree-12 tower /// extension field. That is, it implements the R1CS equivalent of /// `ark_ff::Fp12` diff --git a/src/fields/nonnative/allocated_field_var.rs b/src/fields/nonnative/allocated_field_var.rs new file mode 100644 index 0000000..4c6220b --- /dev/null +++ b/src/fields/nonnative/allocated_field_var.rs @@ -0,0 +1,919 @@ +use super::params::{get_params, OptimizationType}; +use super::reduce::{bigint_to_basefield, limbs_to_bigint, Reducer}; +use super::AllocatedNonNativeFieldMulResultVar; +use crate::fields::fp::FpVar; +use crate::prelude::*; +use crate::ToConstraintFieldGadget; +use ark_ff::{BigInteger, FpParameters, PrimeField}; +use ark_relations::r1cs::{OptimizationGoal, Result as R1CSResult}; +use ark_relations::{ + ns, + r1cs::{ConstraintSystemRef, Namespace, SynthesisError}, +}; +use ark_std::cmp::{max, min}; +use ark_std::marker::PhantomData; +use ark_std::{borrow::Borrow, vec, vec::Vec}; + +/// The allocated version of `NonNativeFieldVar` (introduced below) +#[derive(Debug)] +#[must_use] +pub struct AllocatedNonNativeFieldVar { + /// Constraint system reference + pub cs: ConstraintSystemRef, + /// The limbs, each of which is a BaseField gadget. + pub limbs: Vec>, + /// Number of additions done over this gadget, using which the gadget decides when to reduce. + pub num_of_additions_over_normal_form: BaseField, + /// Whether the limb representation is the normal form (using only the bits specified in the parameters, and the representation is strictly within the range of TargetField). + pub is_in_the_normal_form: bool, + #[doc(hidden)] + pub target_phantom: PhantomData, +} + +impl + AllocatedNonNativeFieldVar +{ + /// Return cs + pub fn cs(&self) -> ConstraintSystemRef { + self.cs.clone() + } + + /// Obtain the value of limbs + pub fn limbs_to_value( + limbs: Vec, + optimization_type: OptimizationType, + ) -> TargetField { + let params = get_params( + TargetField::size_in_bits(), + BaseField::size_in_bits(), + optimization_type, + ); + + let mut base_repr: ::BigInt = TargetField::one().into_repr(); + + // Convert 2^{(params.bits_per_limb - 1)} into the TargetField and then double the base + // This is because 2^{(params.bits_per_limb)} might indeed be larger than the target field's prime. + base_repr.muln((params.bits_per_limb - 1) as u32); + let mut base: TargetField = TargetField::from_repr(base_repr).unwrap(); + base = base + &base; + + let mut result = TargetField::zero(); + let mut power = TargetField::one(); + + for limb in limbs.iter().rev() { + let mut val = TargetField::zero(); + let mut cur = TargetField::one(); + + for bit in limb.into_repr().to_bits_be().iter().rev() { + if *bit { + val += &cur; + } + cur.double_in_place(); + } + + result += &(val * power); + power *= &base; + } + + result + } + + /// Obtain the value of a nonnative field element + pub fn value(&self) -> R1CSResult { + let mut limbs = Vec::new(); + for limb in self.limbs.iter() { + limbs.push(limb.value()?); + } + + Ok(Self::limbs_to_value(limbs, self.get_optimization_type())) + } + + /// Obtain the nonnative field element of a constant value + pub fn constant(cs: ConstraintSystemRef, value: TargetField) -> R1CSResult { + let optimization_type = match cs.optimization_goal() { + OptimizationGoal::None => OptimizationType::Constraints, + OptimizationGoal::Constraints => OptimizationType::Constraints, + OptimizationGoal::Weight => OptimizationType::Weight, + }; + + let limbs_value = Self::get_limbs_representations(&value, optimization_type)?; + + let mut limbs = Vec::new(); + + for limb_value in limbs_value.iter() { + limbs.push(FpVar::::new_constant( + ns!(cs, "limb"), + limb_value, + )?); + } + + Ok(Self { + cs, + limbs, + num_of_additions_over_normal_form: BaseField::zero(), + is_in_the_normal_form: true, + target_phantom: PhantomData, + }) + } + + /// Obtain the nonnative field element of one + pub fn one(cs: ConstraintSystemRef) -> R1CSResult { + Self::constant(cs, TargetField::one()) + } + + /// Obtain the nonnative field element of zero + pub fn zero(cs: ConstraintSystemRef) -> R1CSResult { + Self::constant(cs, TargetField::zero()) + } + + /// Add a nonnative field element + #[tracing::instrument(target = "r1cs")] + pub fn add(&self, other: &Self) -> R1CSResult { + assert_eq!(self.get_optimization_type(), other.get_optimization_type()); + + let mut limbs = Vec::new(); + for (this_limb, other_limb) in self.limbs.iter().zip(other.limbs.iter()) { + limbs.push(this_limb + other_limb); + } + + let mut res = Self { + cs: self.cs(), + limbs, + num_of_additions_over_normal_form: self + .num_of_additions_over_normal_form + .add(&other.num_of_additions_over_normal_form) + .add(&BaseField::one()), + is_in_the_normal_form: false, + target_phantom: PhantomData, + }; + + Reducer::::post_add_reduce(&mut res)?; + Ok(res) + } + + /// Add a constant + #[tracing::instrument(target = "r1cs")] + pub fn add_constant(&self, other: &TargetField) -> R1CSResult { + let other_limbs = Self::get_limbs_representations(other, self.get_optimization_type())?; + + let mut limbs = Vec::new(); + for (this_limb, other_limb) in self.limbs.iter().zip(other_limbs.iter()) { + limbs.push(this_limb + *other_limb); + } + + let mut res = Self { + cs: self.cs(), + limbs, + num_of_additions_over_normal_form: self + .num_of_additions_over_normal_form + .add(&BaseField::one()), + is_in_the_normal_form: false, + target_phantom: PhantomData, + }; + + Reducer::::post_add_reduce(&mut res)?; + + Ok(res) + } + + /// Subtract a nonnative field element, without the final reduction step + #[tracing::instrument(target = "r1cs")] + pub fn sub_without_reduce(&self, other: &Self) -> R1CSResult { + assert_eq!(self.get_optimization_type(), other.get_optimization_type()); + + let params = get_params( + TargetField::size_in_bits(), + BaseField::size_in_bits(), + self.get_optimization_type(), + ); + + // Step 1: reduce the `other` if needed + let mut surfeit = overhead!(other.num_of_additions_over_normal_form + BaseField::one()) + 1; + let mut other = other.clone(); + if (surfeit + params.bits_per_limb > BaseField::size_in_bits() - 1) + || (surfeit + + (TargetField::size_in_bits() - params.bits_per_limb * (params.num_limbs - 1)) + > BaseField::size_in_bits() - 1) + { + Reducer::reduce(&mut other)?; + surfeit = overhead!(other.num_of_additions_over_normal_form + BaseField::one()) + 1; + } + + // Step 2: construct the padding + let mut pad_non_top_limb_repr: ::BigInt = + BaseField::one().into_repr(); + let mut pad_top_limb_repr: ::BigInt = pad_non_top_limb_repr; + + pad_non_top_limb_repr.muln((surfeit + params.bits_per_limb) as u32); + let pad_non_top_limb = BaseField::from_repr(pad_non_top_limb_repr).unwrap(); + + pad_top_limb_repr.muln( + (surfeit + + (TargetField::size_in_bits() - params.bits_per_limb * (params.num_limbs - 1))) + as u32, + ); + let pad_top_limb = BaseField::from_repr(pad_top_limb_repr).unwrap(); + + let mut pad_limbs = Vec::new(); + pad_limbs.push(pad_top_limb); + for _ in 0..self.limbs.len() - 1 { + pad_limbs.push(pad_non_top_limb); + } + + // Step 3: prepare to pad the padding to k * p for some k + let pad_to_kp_gap = Self::limbs_to_value(pad_limbs, self.get_optimization_type()).neg(); + let pad_to_kp_limbs = + Self::get_limbs_representations(&pad_to_kp_gap, self.get_optimization_type())?; + + // Step 4: the result is self + pad + pad_to_kp - other + let mut limbs = Vec::new(); + for (i, ((this_limb, other_limb), pad_to_kp_limb)) in self + .limbs + .iter() + .zip(other.limbs.iter()) + .zip(pad_to_kp_limbs.iter()) + .enumerate() + { + if i != 0 { + limbs.push(this_limb + pad_non_top_limb + *pad_to_kp_limb - other_limb); + } else { + limbs.push(this_limb + pad_top_limb + *pad_to_kp_limb - other_limb); + } + } + + let result = AllocatedNonNativeFieldVar:: { + cs: self.cs(), + limbs, + num_of_additions_over_normal_form: self.num_of_additions_over_normal_form + + (other.num_of_additions_over_normal_form + BaseField::one()) + + (other.num_of_additions_over_normal_form + BaseField::one()), + is_in_the_normal_form: false, + target_phantom: PhantomData, + }; + + Ok(result) + } + + /// Subtract a nonnative field element + #[tracing::instrument(target = "r1cs")] + pub fn sub(&self, other: &Self) -> R1CSResult { + assert_eq!(self.get_optimization_type(), other.get_optimization_type()); + + let mut result = self.sub_without_reduce(other)?; + Reducer::::post_add_reduce(&mut result)?; + Ok(result) + } + + /// Subtract a constant + #[tracing::instrument(target = "r1cs")] + pub fn sub_constant(&self, other: &TargetField) -> R1CSResult { + self.sub(&Self::constant(self.cs(), *other)?) + } + + /// Multiply a nonnative field element + #[tracing::instrument(target = "r1cs")] + pub fn mul(&self, other: &Self) -> R1CSResult { + assert_eq!(self.get_optimization_type(), other.get_optimization_type()); + + self.mul_without_reduce(&other)?.reduce() + } + + /// Multiply a constant + pub fn mul_constant(&self, other: &TargetField) -> R1CSResult { + self.mul(&Self::constant(self.cs(), *other)?) + } + + /// Compute the negate of a nonnative field element + #[tracing::instrument(target = "r1cs")] + pub fn negate(&self) -> R1CSResult { + Self::zero(self.cs())?.sub(self) + } + + /// Compute the inverse of a nonnative field element + #[tracing::instrument(target = "r1cs")] + pub fn inverse(&self) -> R1CSResult { + let inverse = Self::new_witness(self.cs(), || { + Ok(self.value()?.inverse().unwrap_or_else(TargetField::zero)) + })?; + + let actual_result = self.clone().mul(&inverse)?; + actual_result.conditional_enforce_equal(&Self::one(self.cs())?, &Boolean::TRUE)?; + Ok(inverse) + } + + /// Convert a `TargetField` element into limbs (not constraints) + /// This is an internal function that would be reused by a number of other functions + pub fn get_limbs_representations( + elem: &TargetField, + optimization_type: OptimizationType, + ) -> R1CSResult> { + Self::get_limbs_representations_from_big_integer(&elem.into_repr(), optimization_type) + } + + /// Obtain the limbs directly from a big int + pub fn get_limbs_representations_from_big_integer( + elem: &::BigInt, + optimization_type: OptimizationType, + ) -> R1CSResult> { + let params = get_params( + TargetField::size_in_bits(), + BaseField::size_in_bits(), + optimization_type, + ); + + // push the lower limbs first + let mut limbs: Vec = Vec::new(); + let mut cur = *elem; + for _ in 0..params.num_limbs { + let cur_bits = cur.to_bits_be(); // `to_bits` is big endian + let cur_mod_r = ::BigInt::from_bits_be( + &cur_bits[cur_bits.len() - params.bits_per_limb..], + ); // therefore, the lowest `bits_per_non_top_limb` bits is what we want. + limbs.push(BaseField::from_repr(cur_mod_r).unwrap()); + cur.divn(params.bits_per_limb as u32); + } + + // then we reserve, so that the limbs are ``big limb first'' + limbs.reverse(); + + Ok(limbs) + } + + /// for advanced use, multiply and output the intermediate representations (without reduction) + /// This intermediate representations can be added with each other, and they can later be reduced back to the `NonNativeFieldVar`. + #[tracing::instrument(target = "r1cs")] + pub fn mul_without_reduce( + &self, + other: &Self, + ) -> R1CSResult> { + assert_eq!(self.get_optimization_type(), other.get_optimization_type()); + + let params = get_params( + TargetField::size_in_bits(), + BaseField::size_in_bits(), + self.get_optimization_type(), + ); + + // Step 1: reduce `self` and `other` if neceessary + let mut self_reduced = self.clone(); + let mut other_reduced = other.clone(); + Reducer::::pre_mul_reduce(&mut self_reduced, &mut other_reduced)?; + + let mut prod_limbs = Vec::new(); + if self.get_optimization_type() == OptimizationType::Weight { + let zero = FpVar::::zero(); + + for _ in 0..2 * params.num_limbs - 1 { + prod_limbs.push(zero.clone()); + } + + for i in 0..params.num_limbs { + for j in 0..params.num_limbs { + prod_limbs[i + j] = + &prod_limbs[i + j] + (&self_reduced.limbs[i] * &other_reduced.limbs[j]); + } + } + } else { + let cs = self.cs().or(other.cs()); + + for z_index in 0..2 * params.num_limbs - 1 { + prod_limbs.push(FpVar::new_witness(ns!(cs, "limb product"), || { + let mut z_i = BaseField::zero(); + for i in 0..=min(params.num_limbs - 1, z_index) { + let j = z_index - i; + if j < params.num_limbs { + z_i += &self_reduced.limbs[i] + .value()? + .mul(&other_reduced.limbs[j].value()?); + } + } + + Ok(z_i) + })?); + } + + for c in 0..(2 * params.num_limbs - 1) { + let c_pows: Vec<_> = (0..(2 * params.num_limbs - 1)) + .map(|i| BaseField::from((c + 1) as u128).pow(&vec![i as u64])) + .collect(); + + let x = self_reduced + .limbs + .iter() + .zip(c_pows.iter()) + .map(|(var, c_pow)| var * *c_pow) + .fold(FpVar::zero(), |sum, i| sum + i); + + let y = other_reduced + .limbs + .iter() + .zip(c_pows.iter()) + .map(|(var, c_pow)| var * *c_pow) + .fold(FpVar::zero(), |sum, i| sum + i); + + let z = prod_limbs + .iter() + .zip(c_pows.iter()) + .map(|(var, c_pow)| var * *c_pow) + .fold(FpVar::zero(), |sum, i| sum + i); + + z.enforce_equal(&(x * y))?; + } + } + + Ok(AllocatedNonNativeFieldMulResultVar { + cs: self.cs(), + limbs: prod_limbs, + prod_of_num_of_additions: (self_reduced.num_of_additions_over_normal_form + + BaseField::one()) + * (other_reduced.num_of_additions_over_normal_form + BaseField::one()), + target_phantom: PhantomData, + }) + } + + pub(crate) fn frobenius_map(&self, _power: usize) -> R1CSResult { + Ok(self.clone()) + } + + pub(crate) fn conditional_enforce_equal( + &self, + other: &Self, + should_enforce: &Boolean, + ) -> R1CSResult<()> { + assert_eq!(self.get_optimization_type(), other.get_optimization_type()); + + let params = get_params( + TargetField::size_in_bits(), + BaseField::size_in_bits(), + self.get_optimization_type(), + ); + + // Get p + let p_representations = + AllocatedNonNativeFieldVar::::get_limbs_representations_from_big_integer( + &::Params::MODULUS, + self.get_optimization_type() + )?; + let p_bigint = limbs_to_bigint(params.bits_per_limb, &p_representations); + + let mut p_gadget_limbs = Vec::new(); + for limb in p_representations.iter() { + p_gadget_limbs.push(FpVar::::Constant(*limb)); + } + let p_gadget = AllocatedNonNativeFieldVar:: { + cs: self.cs(), + limbs: p_gadget_limbs, + num_of_additions_over_normal_form: BaseField::one(), + is_in_the_normal_form: false, + target_phantom: PhantomData, + }; + + // Get delta = self - other + let cs = self.cs().or(other.cs()).or(should_enforce.cs()); + let mut delta = self.sub_without_reduce(other)?; + delta = should_enforce.select(&delta, &Self::zero(cs.clone())?)?; + + // Allocate k = delta / p + let k_gadget = FpVar::::new_witness(ns!(cs, "k"), || { + let mut delta_limbs_values = Vec::::new(); + for limb in delta.limbs.iter() { + delta_limbs_values.push(limb.value()?); + } + + let delta_bigint = limbs_to_bigint(params.bits_per_limb, &delta_limbs_values); + + Ok(bigint_to_basefield::(&(delta_bigint / p_bigint))) + })?; + + let surfeit = overhead!(delta.num_of_additions_over_normal_form + BaseField::one()) + 1; + Reducer::::limb_to_bits(&k_gadget, surfeit)?; + + // Compute k * p + let mut kp_gadget_limbs = Vec::new(); + for limb in p_gadget.limbs.iter() { + kp_gadget_limbs.push(limb * &k_gadget); + } + + // Enforce delta = kp + Reducer::::group_and_check_equality( + surfeit, + params.bits_per_limb, + params.bits_per_limb, + &delta.limbs, + &kp_gadget_limbs, + )?; + + Ok(()) + } + + #[tracing::instrument(target = "r1cs")] + pub(crate) fn conditional_enforce_not_equal( + &self, + other: &Self, + should_enforce: &Boolean, + ) -> R1CSResult<()> { + assert_eq!(self.get_optimization_type(), other.get_optimization_type()); + + let cs = self.cs().or(other.cs()).or(should_enforce.cs()); + + let _ = should_enforce + .select(&self.sub(other)?, &Self::one(cs)?)? + .inverse()?; + + Ok(()) + } + + pub(crate) fn get_optimization_type(&self) -> OptimizationType { + match self.cs().optimization_goal() { + OptimizationGoal::None => OptimizationType::Constraints, + OptimizationGoal::Constraints => OptimizationType::Constraints, + OptimizationGoal::Weight => OptimizationType::Weight, + } + } + + /// Allocates a new variable, but does not check that the allocation's limbs are in-range. + fn new_variable_unchecked>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> R1CSResult { + let ns = cs.into(); + let cs = ns.cs(); + + let optimization_type = match cs.optimization_goal() { + OptimizationGoal::None => OptimizationType::Constraints, + OptimizationGoal::Constraints => OptimizationType::Constraints, + OptimizationGoal::Weight => OptimizationType::Weight, + }; + + let zero = TargetField::zero(); + + let elem = match f() { + Ok(t) => *(t.borrow()), + Err(_) => zero, + }; + let elem_representations = Self::get_limbs_representations(&elem, optimization_type)?; + let mut limbs = Vec::new(); + + for limb in elem_representations.iter() { + limbs.push(FpVar::::new_variable( + ark_relations::ns!(cs, "alloc"), + || Ok(limb), + mode, + )?); + } + + let num_of_additions_over_normal_form = if mode != AllocationMode::Witness { + BaseField::zero() + } else { + BaseField::one() + }; + + Ok(Self { + cs, + limbs, + num_of_additions_over_normal_form, + is_in_the_normal_form: mode != AllocationMode::Witness, + target_phantom: PhantomData, + }) + } + + /// Check that this element is in-range; i.e., each limb is in-range, and the whole number is + /// less than the modulus. + /// + /// Returns the bits of the element, in little-endian form + fn enforce_in_range( + &self, + cs: impl Into>, + ) -> R1CSResult>> { + let ns = cs.into(); + let cs = ns.cs(); + let optimization_type = match cs.optimization_goal() { + OptimizationGoal::None => OptimizationType::Constraints, + OptimizationGoal::Constraints => OptimizationType::Constraints, + OptimizationGoal::Weight => OptimizationType::Weight, + }; + let params = get_params( + TargetField::size_in_bits(), + BaseField::size_in_bits(), + optimization_type, + ); + let mut bits = Vec::new(); + for limb in self.limbs.iter().rev().take(params.num_limbs - 1) { + bits.extend( + Reducer::::limb_to_bits(limb, params.bits_per_limb)? + .into_iter() + .rev(), + ); + } + + bits.extend( + Reducer::::limb_to_bits( + &self.limbs[0], + TargetField::size_in_bits() - (params.num_limbs - 1) * params.bits_per_limb, + )? + .into_iter() + .rev(), + ); + Ok(bits) + } + + /// Allocates a new non-native field witness with value given by the function `f`. Enforces + /// that the field element has value in `[0, modulus)`, and returns the bits of its binary + /// representation. The bits are in little-endian (i.e., the bit at index 0 is the LSB) and the + /// bit-vector is empty in non-witness allocation modes. + pub fn new_witness_with_le_bits>( + cs: impl Into>, + f: impl FnOnce() -> Result, + ) -> R1CSResult<(Self, Vec>)> { + let ns = cs.into(); + let cs = ns.cs(); + let this = Self::new_variable_unchecked(ns!(cs, "alloc"), f, AllocationMode::Witness)?; + let bits = this.enforce_in_range(ns!(cs, "bits"))?; + Ok((this, bits)) + } +} + +impl ToBitsGadget + for AllocatedNonNativeFieldVar +{ + #[tracing::instrument(target = "r1cs")] + fn to_bits_le(&self) -> R1CSResult>> { + let params = get_params( + TargetField::size_in_bits(), + BaseField::size_in_bits(), + self.get_optimization_type(), + ); + + // Reduce to the normal form + // Though, a malicious prover can make it slightly larger than p + let mut self_normal = self.clone(); + Reducer::::pre_eq_reduce(&mut self_normal)?; + + // Therefore, we convert it to bits and enforce that it is in the field + let mut bits = Vec::>::new(); + for limb in self_normal.limbs.iter() { + bits.extend_from_slice(&Reducer::::limb_to_bits( + &limb, + params.bits_per_limb, + )?); + } + bits.reverse(); + + let mut b = TargetField::characteristic().to_vec(); + assert_eq!(b[0] % 2, 1); + b[0] -= 1; // This works, because the LSB is one, so there's no borrows. + let run = Boolean::::enforce_smaller_or_equal_than_le(&bits, b)?; + + // We should always end in a "run" of zeros, because + // the characteristic is an odd prime. So, this should + // be empty. + assert!(run.is_empty()); + + Ok(bits) + } +} + +impl ToBytesGadget + for AllocatedNonNativeFieldVar +{ + #[tracing::instrument(target = "r1cs")] + fn to_bytes(&self) -> R1CSResult>> { + let mut bits = self.to_bits_le()?; + + let num_bits = TargetField::BigInt::NUM_LIMBS * 64; + assert!(bits.len() <= num_bits); + bits.resize_with(num_bits, || Boolean::constant(false)); + + let bytes = bits.chunks(8).map(UInt8::from_bits_le).collect(); + Ok(bytes) + } +} + +impl CondSelectGadget + for AllocatedNonNativeFieldVar +{ + #[tracing::instrument(target = "r1cs")] + fn conditionally_select( + cond: &Boolean, + true_value: &Self, + false_value: &Self, + ) -> R1CSResult { + assert_eq!( + true_value.get_optimization_type(), + false_value.get_optimization_type() + ); + + let mut limbs_sel = Vec::with_capacity(true_value.limbs.len()); + + for (x, y) in true_value.limbs.iter().zip(&false_value.limbs) { + limbs_sel.push(FpVar::::conditionally_select(cond, x, y)?); + } + + Ok(Self { + cs: true_value.cs().or(false_value.cs()), + limbs: limbs_sel, + num_of_additions_over_normal_form: max( + true_value.num_of_additions_over_normal_form, + false_value.num_of_additions_over_normal_form, + ), + is_in_the_normal_form: true_value.is_in_the_normal_form + && false_value.is_in_the_normal_form, + target_phantom: PhantomData, + }) + } +} + +impl TwoBitLookupGadget + for AllocatedNonNativeFieldVar +{ + type TableConstant = TargetField; + + #[tracing::instrument(target = "r1cs")] + fn two_bit_lookup( + bits: &[Boolean], + constants: &[Self::TableConstant], + ) -> R1CSResult { + debug_assert!(bits.len() == 2); + debug_assert!(constants.len() == 4); + + let cs = bits.cs(); + + let optimization_type = match cs.optimization_goal() { + OptimizationGoal::None => OptimizationType::Constraints, + OptimizationGoal::Constraints => OptimizationType::Constraints, + OptimizationGoal::Weight => OptimizationType::Weight, + }; + + let params = get_params( + TargetField::size_in_bits(), + BaseField::size_in_bits(), + optimization_type, + ); + let mut limbs_constants = Vec::new(); + for _ in 0..params.num_limbs { + limbs_constants.push(Vec::new()); + } + + for constant in constants.iter() { + let representations = + AllocatedNonNativeFieldVar::::get_limbs_representations( + constant, + optimization_type, + )?; + + for (i, representation) in representations.iter().enumerate() { + limbs_constants[i].push(*representation); + } + } + + let mut limbs = Vec::new(); + for limbs_constant in limbs_constants.iter() { + limbs.push(FpVar::::two_bit_lookup(bits, limbs_constant)?); + } + + Ok(AllocatedNonNativeFieldVar:: { + cs, + limbs, + num_of_additions_over_normal_form: BaseField::zero(), + is_in_the_normal_form: true, + target_phantom: PhantomData, + }) + } +} + +impl ThreeBitCondNegLookupGadget + for AllocatedNonNativeFieldVar +{ + type TableConstant = TargetField; + + #[tracing::instrument(target = "r1cs")] + fn three_bit_cond_neg_lookup( + bits: &[Boolean], + b0b1: &Boolean, + constants: &[Self::TableConstant], + ) -> R1CSResult { + debug_assert!(bits.len() == 3); + debug_assert!(constants.len() == 4); + + let cs = bits.cs().or(b0b1.cs()); + + let optimization_type = match cs.optimization_goal() { + OptimizationGoal::None => OptimizationType::Constraints, + OptimizationGoal::Constraints => OptimizationType::Constraints, + OptimizationGoal::Weight => OptimizationType::Weight, + }; + + let params = get_params( + TargetField::size_in_bits(), + BaseField::size_in_bits(), + optimization_type, + ); + + let mut limbs_constants = Vec::new(); + for _ in 0..params.num_limbs { + limbs_constants.push(Vec::new()); + } + + for constant in constants.iter() { + let representations = + AllocatedNonNativeFieldVar::::get_limbs_representations( + constant, + optimization_type, + )?; + + for (i, representation) in representations.iter().enumerate() { + limbs_constants[i].push(*representation); + } + } + + let mut limbs = Vec::new(); + for limbs_constant in limbs_constants.iter() { + limbs.push(FpVar::::three_bit_cond_neg_lookup( + bits, + b0b1, + limbs_constant, + )?); + } + + Ok(AllocatedNonNativeFieldVar:: { + cs, + limbs, + num_of_additions_over_normal_form: BaseField::zero(), + is_in_the_normal_form: true, + target_phantom: PhantomData, + }) + } +} + +impl AllocVar + for AllocatedNonNativeFieldVar +{ + fn new_variable>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> R1CSResult { + let ns = cs.into(); + let cs = ns.cs(); + let this = Self::new_variable_unchecked(ns!(cs, "alloc"), f, mode)?; + if mode == AllocationMode::Witness { + this.enforce_in_range(ns!(cs, "bits"))?; + } + Ok(this) + } +} + +impl ToConstraintFieldGadget + for AllocatedNonNativeFieldVar +{ + fn to_constraint_field(&self) -> R1CSResult>> { + // provide a unique representation of the nonnative variable + // step 1: convert it into a bit sequence + let bits = self.to_bits_le()?; + + // step 2: obtain the parameters for weight-optimized (often, fewer limbs) + let params = get_params( + TargetField::size_in_bits(), + BaseField::size_in_bits(), + OptimizationType::Weight, + ); + + // step 3: assemble the limbs + let mut limbs = bits + .chunks(params.bits_per_limb) + .map(|chunk| { + let mut limb = FpVar::::zero(); + let mut w = BaseField::one(); + for b in chunk.iter() { + limb += FpVar::from(b.clone()) * w; + w.double_in_place(); + } + limb + }) + .collect::>>(); + + limbs.reverse(); + + // step 4: output the limbs + Ok(limbs) + } +} + +/* + * Implementation of a few traits + */ + +impl Clone + for AllocatedNonNativeFieldVar +{ + fn clone(&self) -> Self { + AllocatedNonNativeFieldVar { + cs: self.cs(), + limbs: self.limbs.clone(), + num_of_additions_over_normal_form: self.num_of_additions_over_normal_form, + is_in_the_normal_form: self.is_in_the_normal_form, + target_phantom: PhantomData, + } + } +} diff --git a/src/fields/nonnative/allocated_mul_result.rs b/src/fields/nonnative/allocated_mul_result.rs new file mode 100644 index 0000000..28bf1e8 --- /dev/null +++ b/src/fields/nonnative/allocated_mul_result.rs @@ -0,0 +1,289 @@ +use super::params::{get_params, OptimizationType}; +use super::reduce::{bigint_to_basefield, limbs_to_bigint, Reducer}; +use super::AllocatedNonNativeFieldVar; +use crate::fields::fp::FpVar; +use crate::prelude::*; +use ark_ff::{FpParameters, PrimeField}; +use ark_relations::r1cs::{OptimizationGoal, Result as R1CSResult}; +use ark_relations::{ns, r1cs::ConstraintSystemRef}; +use ark_std::marker::PhantomData; +use ark_std::vec::Vec; +use num_bigint::BigUint; + +/// The allocated form of `NonNativeFieldMulResultVar` (introduced below) +#[derive(Debug)] +#[must_use] +pub struct AllocatedNonNativeFieldMulResultVar { + /// Constraint system reference + pub cs: ConstraintSystemRef, + /// Limbs of the intermediate representations + pub limbs: Vec>, + /// The cumulative num of additions + pub prod_of_num_of_additions: BaseField, + #[doc(hidden)] + pub target_phantom: PhantomData, +} + +impl + From<&AllocatedNonNativeFieldVar> + for AllocatedNonNativeFieldMulResultVar +{ + fn from(src: &AllocatedNonNativeFieldVar) -> Self { + let params = get_params( + TargetField::size_in_bits(), + BaseField::size_in_bits(), + src.get_optimization_type(), + ); + + let mut limbs = src.limbs.clone(); + limbs.reverse(); + limbs.resize(2 * params.num_limbs - 1, FpVar::::zero()); + limbs.reverse(); + + let prod_of_num_of_additions = src.num_of_additions_over_normal_form + &BaseField::one(); + + Self { + cs: src.cs(), + limbs, + prod_of_num_of_additions, + target_phantom: PhantomData, + } + } +} + +impl + AllocatedNonNativeFieldMulResultVar +{ + /// Get the CS + pub fn cs(&self) -> ConstraintSystemRef { + self.cs.clone() + } + + /// Get the value of the multiplication result + pub fn value(&self) -> R1CSResult { + let params = get_params( + TargetField::size_in_bits(), + BaseField::size_in_bits(), + self.get_optimization_type(), + ); + + let p_representations = + AllocatedNonNativeFieldVar::::get_limbs_representations_from_big_integer( + &::Params::MODULUS, + self.get_optimization_type() + )?; + let p_bigint = limbs_to_bigint(params.bits_per_limb, &p_representations); + + let mut limbs_values = Vec::::new(); + for limb in self.limbs.iter() { + limbs_values.push(limb.value().unwrap_or_default()); + } + let value_bigint = limbs_to_bigint(params.bits_per_limb, &limbs_values); + + let res = bigint_to_basefield::(&(value_bigint % p_bigint)); + Ok(res) + } + + /// Constraints for reducing the result of a multiplication mod p, to get an original representation. + pub fn reduce(&self) -> R1CSResult> { + let params = get_params( + TargetField::size_in_bits(), + BaseField::size_in_bits(), + self.get_optimization_type(), + ); + + // Step 1: get p + let p_representations = + AllocatedNonNativeFieldVar::::get_limbs_representations_from_big_integer( + &::Params::MODULUS, + self.get_optimization_type() + )?; + let p_bigint = limbs_to_bigint(params.bits_per_limb, &p_representations); + + let mut p_gadget_limbs = Vec::new(); + for limb in p_representations.iter() { + p_gadget_limbs.push(FpVar::::new_constant(self.cs(), limb)?); + } + let p_gadget = AllocatedNonNativeFieldVar:: { + cs: self.cs(), + limbs: p_gadget_limbs, + num_of_additions_over_normal_form: BaseField::one(), + is_in_the_normal_form: false, + target_phantom: PhantomData, + }; + + // Step 2: compute surfeit + let surfeit = overhead!(self.prod_of_num_of_additions + BaseField::one()) + 1 + 1; + + // Step 3: allocate k + let k_bits = { + let mut res = Vec::new(); + + let mut limbs_values = Vec::::new(); + for limb in self.limbs.iter() { + limbs_values.push(limb.value().unwrap_or_default()); + } + + let value_bigint = limbs_to_bigint(params.bits_per_limb, &limbs_values); + let mut k_cur = value_bigint / p_bigint; + + let total_len = TargetField::size_in_bits() + surfeit; + + for _ in 0..total_len { + res.push(Boolean::::new_witness(self.cs(), || { + Ok(&k_cur % 2u64 == BigUint::from(1u64)) + })?); + k_cur /= 2u64; + } + res + }; + + let k_limbs = { + let zero = FpVar::Constant(BaseField::zero()); + let mut limbs = Vec::new(); + + let mut k_bits_cur = k_bits.clone(); + + for i in 0..params.num_limbs { + let this_limb_size = if i != params.num_limbs - 1 { + params.bits_per_limb + } else { + k_bits.len() - (params.num_limbs - 1) * params.bits_per_limb + }; + + let this_limb_bits = k_bits_cur[0..this_limb_size].to_vec(); + k_bits_cur = k_bits_cur[this_limb_size..].to_vec(); + + let mut limb = zero.clone(); + let mut cur = BaseField::one(); + + for bit in this_limb_bits.iter() { + limb += &(FpVar::::from(bit.clone()) * cur); + cur.double_in_place(); + } + limbs.push(limb); + } + + limbs.reverse(); + limbs + }; + + let k_gadget = AllocatedNonNativeFieldVar:: { + cs: self.cs(), + limbs: k_limbs, + num_of_additions_over_normal_form: self.prod_of_num_of_additions, + is_in_the_normal_form: false, + target_phantom: PhantomData, + }; + + let cs = self.cs(); + + let r_gadget = AllocatedNonNativeFieldVar::::new_witness( + ns!(cs, "r"), + || Ok(self.value()?), + )?; + + let params = get_params( + TargetField::size_in_bits(), + BaseField::size_in_bits(), + self.get_optimization_type(), + ); + + // Step 1: reduce `self` and `other` if neceessary + let mut prod_limbs = Vec::new(); + let zero = FpVar::::zero(); + + for _ in 0..2 * params.num_limbs - 1 { + prod_limbs.push(zero.clone()); + } + + for i in 0..params.num_limbs { + for j in 0..params.num_limbs { + prod_limbs[i + j] = &prod_limbs[i + j] + (&p_gadget.limbs[i] * &k_gadget.limbs[j]); + } + } + + let mut kp_plus_r_gadget = Self { + cs, + limbs: prod_limbs, + prod_of_num_of_additions: (p_gadget.num_of_additions_over_normal_form + + BaseField::one()) + * (k_gadget.num_of_additions_over_normal_form + BaseField::one()), + target_phantom: PhantomData, + }; + + let kp_plus_r_limbs_len = kp_plus_r_gadget.limbs.len(); + for (i, limb) in r_gadget.limbs.iter().rev().enumerate() { + kp_plus_r_gadget.limbs[kp_plus_r_limbs_len - 1 - i] += limb; + } + + Reducer::::group_and_check_equality( + surfeit, + 2 * params.bits_per_limb, + params.bits_per_limb, + &self.limbs, + &kp_plus_r_gadget.limbs, + )?; + + Ok(r_gadget) + } + + /// Add unreduced elements. + #[tracing::instrument(target = "r1cs")] + pub fn add(&self, other: &Self) -> R1CSResult { + assert_eq!(self.get_optimization_type(), other.get_optimization_type()); + + let mut new_limbs = Vec::new(); + + for (l1, l2) in self.limbs.iter().zip(other.limbs.iter()) { + let new_limb = l1 + l2; + new_limbs.push(new_limb); + } + + Ok(Self { + cs: self.cs(), + limbs: new_limbs, + prod_of_num_of_additions: self.prod_of_num_of_additions + + other.prod_of_num_of_additions, + target_phantom: PhantomData, + }) + } + + /// Add native constant elem + #[tracing::instrument(target = "r1cs")] + pub fn add_constant(&self, other: &TargetField) -> R1CSResult { + let mut other_limbs = + AllocatedNonNativeFieldVar::::get_limbs_representations( + other, + self.get_optimization_type(), + )?; + other_limbs.reverse(); + + let mut new_limbs = Vec::new(); + + for (i, limb) in self.limbs.iter().rev().enumerate() { + if i < other_limbs.len() { + new_limbs.push(limb + other_limbs[i]); + } else { + new_limbs.push((*limb).clone()); + } + } + + new_limbs.reverse(); + + Ok(Self { + cs: self.cs(), + limbs: new_limbs, + prod_of_num_of_additions: self.prod_of_num_of_additions + BaseField::one(), + target_phantom: PhantomData, + }) + } + + pub(crate) fn get_optimization_type(&self) -> OptimizationType { + match self.cs().optimization_goal() { + OptimizationGoal::None => OptimizationType::Constraints, + OptimizationGoal::Constraints => OptimizationType::Constraints, + OptimizationGoal::Weight => OptimizationType::Weight, + } + } +} diff --git a/src/fields/nonnative/field_var.rs b/src/fields/nonnative/field_var.rs new file mode 100644 index 0000000..eb7ccf1 --- /dev/null +++ b/src/fields/nonnative/field_var.rs @@ -0,0 +1,494 @@ +use super::params::OptimizationType; +use super::{AllocatedNonNativeFieldVar, NonNativeFieldMulResultVar}; +use crate::boolean::Boolean; +use crate::fields::fp::FpVar; +use crate::fields::FieldVar; +use crate::prelude::*; +use crate::{R1CSVar, ToConstraintFieldGadget}; +use ark_ff::PrimeField; +use ark_ff::{to_bytes, FpParameters}; +use ark_relations::r1cs::Result as R1CSResult; +use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; +use ark_std::hash::{Hash, Hasher}; +use ark_std::{borrow::Borrow, vec::Vec}; + +/// A gadget for representing non-native (`TargetField`) field elements over the constraint field (`BaseField`). +#[derive(Clone, Debug)] +#[must_use] +pub enum NonNativeFieldVar { + /// Constant + Constant(TargetField), + /// Allocated gadget + Var(AllocatedNonNativeFieldVar), +} + +impl PartialEq + for NonNativeFieldVar +{ + fn eq(&self, other: &Self) -> bool { + self.value() + .unwrap_or_default() + .eq(&other.value().unwrap_or_default()) + } +} + +impl Eq + for NonNativeFieldVar +{ +} + +impl Hash + for NonNativeFieldVar +{ + fn hash(&self, state: &mut H) { + self.value().unwrap_or_default().hash(state); + } +} + +impl R1CSVar + for NonNativeFieldVar +{ + type Value = TargetField; + + fn cs(&self) -> ConstraintSystemRef { + match self { + Self::Constant(_) => ConstraintSystemRef::None, + Self::Var(a) => a.cs(), + } + } + + fn value(&self) -> R1CSResult { + match self { + Self::Constant(v) => Ok(*v), + Self::Var(v) => v.value(), + } + } +} + +impl From> + for NonNativeFieldVar +{ + fn from(other: Boolean) -> Self { + if let Boolean::Constant(b) = other { + Self::Constant(>::from(b as u128)) + } else { + // `other` is a variable + let one = Self::Constant(TargetField::one()); + let zero = Self::Constant(TargetField::zero()); + Self::conditionally_select(&other, &one, &zero).unwrap() + } + } +} + +impl + From> + for NonNativeFieldVar +{ + fn from(other: AllocatedNonNativeFieldVar) -> Self { + Self::Var(other) + } +} + +impl<'a, TargetField: PrimeField, BaseField: PrimeField> FieldOpsBounds<'a, TargetField, Self> + for NonNativeFieldVar +{ +} + +impl<'a, TargetField: PrimeField, BaseField: PrimeField> + FieldOpsBounds<'a, TargetField, NonNativeFieldVar> + for &'a NonNativeFieldVar +{ +} + +impl FieldVar + for NonNativeFieldVar +{ + fn zero() -> Self { + Self::Constant(TargetField::zero()) + } + + fn one() -> Self { + Self::Constant(TargetField::one()) + } + + fn constant(v: TargetField) -> Self { + Self::Constant(v) + } + + #[tracing::instrument(target = "r1cs")] + fn negate(&self) -> R1CSResult { + match self { + Self::Constant(c) => Ok(Self::Constant(-*c)), + Self::Var(v) => Ok(Self::Var(v.negate()?)), + } + } + + #[tracing::instrument(target = "r1cs")] + fn inverse(&self) -> R1CSResult { + match self { + Self::Constant(c) => Ok(Self::Constant(c.inverse().unwrap_or_default())), + Self::Var(v) => Ok(Self::Var(v.inverse()?)), + } + } + + #[tracing::instrument(target = "r1cs")] + fn frobenius_map(&self, power: usize) -> R1CSResult { + match self { + Self::Constant(c) => Ok(Self::Constant({ + let mut tmp = *c; + tmp.frobenius_map(power); + tmp + })), + Self::Var(v) => Ok(Self::Var(v.frobenius_map(power)?)), + } + } +} + +/****************************************************************************/ +/****************************************************************************/ + +impl_bounded_ops!( + NonNativeFieldVar, + TargetField, + Add, + add, + AddAssign, + add_assign, + |this: &'a NonNativeFieldVar, other: &'a NonNativeFieldVar| { + use NonNativeFieldVar::*; + match (this, other) { + (Constant(c1), Constant(c2)) => Constant(*c1 + c2), + (Constant(c), Var(v)) | (Var(v), Constant(c)) => Var(v.add_constant(c).unwrap()), + (Var(v1), Var(v2)) => Var(v1.add(v2).unwrap()), + } + }, + |this: &'a NonNativeFieldVar, other: TargetField| { this + &NonNativeFieldVar::Constant(other) }, + (TargetField: PrimeField, BaseField: PrimeField), +); + +impl_bounded_ops!( + NonNativeFieldVar, + TargetField, + Sub, + sub, + SubAssign, + sub_assign, + |this: &'a NonNativeFieldVar, other: &'a NonNativeFieldVar| { + use NonNativeFieldVar::*; + match (this, other) { + (Constant(c1), Constant(c2)) => Constant(*c1 - c2), + (Var(v), Constant(c)) => Var(v.sub_constant(c).unwrap()), + (Constant(c), Var(v)) => Var(v.sub_constant(c).unwrap().negate().unwrap()), + (Var(v1), Var(v2)) => Var(v1.sub(v2).unwrap()), + } + }, + |this: &'a NonNativeFieldVar, other: TargetField| { + this - &NonNativeFieldVar::Constant(other) + }, + (TargetField: PrimeField, BaseField: PrimeField), +); + +impl_bounded_ops!( + NonNativeFieldVar, + TargetField, + Mul, + mul, + MulAssign, + mul_assign, + |this: &'a NonNativeFieldVar, other: &'a NonNativeFieldVar| { + use NonNativeFieldVar::*; + match (this, other) { + (Constant(c1), Constant(c2)) => Constant(*c1 * c2), + (Constant(c), Var(v)) | (Var(v), Constant(c)) => Var(v.mul_constant(c).unwrap()), + (Var(v1), Var(v2)) => Var(v1.mul(v2).unwrap()), + } + }, + |this: &'a NonNativeFieldVar, other: TargetField| { + if other.is_zero() { + NonNativeFieldVar::zero() + } else { + this * &NonNativeFieldVar::Constant(other) + } + }, + (TargetField: PrimeField, BaseField: PrimeField), +); + +/****************************************************************************/ +/****************************************************************************/ + +impl EqGadget + for NonNativeFieldVar +{ + #[tracing::instrument(target = "r1cs")] + fn is_eq(&self, other: &Self) -> R1CSResult> { + let cs = self.cs().or(other.cs()); + + if cs == ConstraintSystemRef::None { + Ok(Boolean::Constant(self.value()? == other.value()?)) + } else { + let should_enforce_equal = + Boolean::new_witness(cs, || Ok(self.value()? == other.value()?))?; + + self.conditional_enforce_equal(other, &should_enforce_equal)?; + self.conditional_enforce_not_equal(other, &should_enforce_equal.not())?; + + Ok(should_enforce_equal) + } + } + + #[tracing::instrument(target = "r1cs")] + fn conditional_enforce_equal( + &self, + other: &Self, + should_enforce: &Boolean, + ) -> R1CSResult<()> { + match (self, other) { + (Self::Constant(c1), Self::Constant(c2)) => { + if c1 != c2 { + should_enforce.enforce_equal(&Boolean::FALSE)?; + } + Ok(()) + } + (Self::Constant(c), Self::Var(v)) | (Self::Var(v), Self::Constant(c)) => { + let cs = v.cs(); + let c = AllocatedNonNativeFieldVar::new_constant(cs, c)?; + c.conditional_enforce_equal(v, should_enforce) + } + (Self::Var(v1), Self::Var(v2)) => v1.conditional_enforce_equal(v2, should_enforce), + } + } + + #[tracing::instrument(target = "r1cs")] + fn conditional_enforce_not_equal( + &self, + other: &Self, + should_enforce: &Boolean, + ) -> R1CSResult<()> { + match (self, other) { + (Self::Constant(c1), Self::Constant(c2)) => { + if c1 == c2 { + should_enforce.enforce_equal(&Boolean::FALSE)?; + } + Ok(()) + } + (Self::Constant(c), Self::Var(v)) | (Self::Var(v), Self::Constant(c)) => { + let cs = v.cs(); + let c = AllocatedNonNativeFieldVar::new_constant(cs, c)?; + c.conditional_enforce_not_equal(v, should_enforce) + } + (Self::Var(v1), Self::Var(v2)) => v1.conditional_enforce_not_equal(v2, should_enforce), + } + } +} + +impl ToBitsGadget + for NonNativeFieldVar +{ + #[tracing::instrument(target = "r1cs")] + fn to_bits_le(&self) -> R1CSResult>> { + match self { + Self::Constant(_) => self.to_non_unique_bits_le(), + Self::Var(v) => v.to_bits_le(), + } + } + + #[tracing::instrument(target = "r1cs")] + fn to_non_unique_bits_le(&self) -> R1CSResult>> { + use ark_ff::BitIteratorLE; + match self { + Self::Constant(c) => Ok(BitIteratorLE::new(&c.into_repr()) + .take((TargetField::Params::MODULUS_BITS) as usize) + .map(Boolean::constant) + .collect::>()), + Self::Var(v) => v.to_non_unique_bits_le(), + } + } +} + +impl ToBytesGadget + for NonNativeFieldVar +{ + /// Outputs the unique byte decomposition of `self` in *little-endian* + /// form. + #[tracing::instrument(target = "r1cs")] + fn to_bytes(&self) -> R1CSResult>> { + match self { + Self::Constant(c) => Ok(UInt8::constant_vec(&to_bytes![c].unwrap())), + Self::Var(v) => v.to_bytes(), + } + } + + #[tracing::instrument(target = "r1cs")] + fn to_non_unique_bytes(&self) -> R1CSResult>> { + match self { + Self::Constant(c) => Ok(UInt8::constant_vec(&to_bytes![c].unwrap())), + Self::Var(v) => v.to_non_unique_bytes(), + } + } +} + +impl CondSelectGadget + for NonNativeFieldVar +{ + #[tracing::instrument(target = "r1cs")] + fn conditionally_select( + cond: &Boolean, + true_value: &Self, + false_value: &Self, + ) -> R1CSResult { + match cond { + Boolean::Constant(true) => Ok(true_value.clone()), + Boolean::Constant(false) => Ok(false_value.clone()), + _ => { + let cs = cond.cs(); + let true_value = match true_value { + Self::Constant(f) => AllocatedNonNativeFieldVar::new_constant(cs.clone(), f)?, + Self::Var(v) => v.clone(), + }; + let false_value = match false_value { + Self::Constant(f) => AllocatedNonNativeFieldVar::new_constant(cs, f)?, + Self::Var(v) => v.clone(), + }; + cond.select(&true_value, &false_value).map(Self::Var) + } + } + } +} + +/// Uses two bits to perform a lookup into a table +/// `b` is little-endian: `b[0]` is LSB. +impl TwoBitLookupGadget + for NonNativeFieldVar +{ + type TableConstant = TargetField; + + #[tracing::instrument(target = "r1cs")] + fn two_bit_lookup(b: &[Boolean], c: &[Self::TableConstant]) -> R1CSResult { + debug_assert_eq!(b.len(), 2); + debug_assert_eq!(c.len(), 4); + if b.cs().is_none() { + // We're in the constant case + + let lsb = b[0].value()? as usize; + let msb = b[1].value()? as usize; + let index = lsb + (msb << 1); + Ok(Self::Constant(c[index])) + } else { + AllocatedNonNativeFieldVar::two_bit_lookup(b, c).map(Self::Var) + } + } +} + +impl ThreeBitCondNegLookupGadget + for NonNativeFieldVar +{ + type TableConstant = TargetField; + + #[tracing::instrument(target = "r1cs")] + fn three_bit_cond_neg_lookup( + b: &[Boolean], + b0b1: &Boolean, + c: &[Self::TableConstant], + ) -> R1CSResult { + debug_assert_eq!(b.len(), 3); + debug_assert_eq!(c.len(), 4); + + if b.cs().or(b0b1.cs()).is_none() { + // We're in the constant case + + let lsb = b[0].value()? as usize; + let msb = b[1].value()? as usize; + let index = lsb + (msb << 1); + let intermediate = c[index]; + + let is_negative = b[2].value()?; + let y = if is_negative { + -intermediate + } else { + intermediate + }; + Ok(Self::Constant(y)) + } else { + AllocatedNonNativeFieldVar::three_bit_cond_neg_lookup(b, b0b1, c).map(Self::Var) + } + } +} + +impl AllocVar + for NonNativeFieldVar +{ + fn new_variable>( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> R1CSResult { + let ns = cs.into(); + let cs = ns.cs(); + + if cs == ConstraintSystemRef::None || mode == AllocationMode::Constant { + Ok(Self::Constant(*f()?.borrow())) + } else { + AllocatedNonNativeFieldVar::new_variable(cs, f, mode).map(Self::Var) + } + } +} + +impl ToConstraintFieldGadget + for NonNativeFieldVar +{ + #[tracing::instrument(target = "r1cs")] + fn to_constraint_field(&self) -> R1CSResult>> { + // Use one group element to represent the optimization type. + // + // By default, the constant is converted in the weight-optimized type, because it results in fewer elements. + match self { + Self::Constant(c) => Ok(AllocatedNonNativeFieldVar::get_limbs_representations( + c, + OptimizationType::Weight, + )? + .into_iter() + .map(FpVar::constant) + .collect()), + Self::Var(v) => v.to_constraint_field(), + } + } +} + +impl NonNativeFieldVar { + /// The `mul_without_reduce` for `NonNativeFieldVar` + #[tracing::instrument(target = "r1cs")] + pub fn mul_without_reduce( + &self, + other: &Self, + ) -> R1CSResult> { + match self { + Self::Constant(c) => match other { + Self::Constant(other_c) => Ok(NonNativeFieldMulResultVar::Constant(*c * other_c)), + Self::Var(other_v) => { + let self_v = + AllocatedNonNativeFieldVar::::new_constant( + self.cs(), + c, + )?; + Ok(NonNativeFieldMulResultVar::Var( + other_v.mul_without_reduce(&self_v)?, + )) + } + }, + Self::Var(v) => { + let other_v = match other { + Self::Constant(other_c) => { + AllocatedNonNativeFieldVar::::new_constant( + self.cs(), + other_c, + )? + } + Self::Var(other_v) => other_v.clone(), + }; + Ok(NonNativeFieldMulResultVar::Var( + v.mul_without_reduce(&other_v)?, + )) + } + } + } +} diff --git a/src/fields/nonnative/mod.rs b/src/fields/nonnative/mod.rs new file mode 100644 index 0000000..45e2b90 --- /dev/null +++ b/src/fields/nonnative/mod.rs @@ -0,0 +1,180 @@ +//! +//! ## Overview +//! +//! This module implements a field gadget for a prime field `Fp` over another prime field `Fq` where `p != q`. +//! +//! When writing constraint systems for many cryptographic proofs, we are restricted to a native field (e.g., the scalar field of the pairing-friendly curve). +//! This can be inconvenient; for example, the recursive composition of proofs via cycles of curves requires the verifier to compute over a non-native field. +//! +//! The library makes it possible to write computations over a non-native field in the same way one would write computations over the native field. This naturally introduces additional overhead, which we minimize using a variety of optimizations. (Nevertheless, the overhead is still substantial, and native fields should be used where possible.) +//! +//! ## Usage +//! +//! Because [`NonNativeFieldVar`] implements the [`FieldVar`] trait in arkworks, we can treat it like a native field variable ([`FpVar`]). +//! +//! We can do the standard field operations, such as `+`, `-`, and `*`. See the following example: +//! +//! ```rust +//! # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> { +//! # use ark_std::UniformRand; +//! # use ark_relations::{ns, r1cs::ConstraintSystem}; +//! # use ark_r1cs_std::prelude::*; +//! use ark_r1cs_std::fields::nonnative::NonNativeFieldVar; +//! use ark_bls12_377::{Fr, Fq}; +//! +//! # let mut rng = ark_std::test_rng(); +//! # let a_value = Fr::rand(&mut rng); +//! # let b_value = Fr::rand(&mut rng); +//! # let cs = ConstraintSystem::::new_ref(); +//! +//! let a = NonNativeFieldVar::::new_witness(ns!(cs, "a"), || Ok(a_value))?; +//! let b = NonNativeFieldVar::::new_witness(ns!(cs, "b"), || Ok(b_value))?; +//! +//! // add +//! let a_plus_b = &a + &b; +//! +//! // sub +//! let a_minus_b = &a - &b; +//! +//! // multiply +//! let a_times_b = &a * &b; +//! +//! // enforce equality +//! a.enforce_equal(&b)?; +//! # Ok(()) +//! # } +//! ``` +//! +//! ## Advanced optimization +//! +//! After each multiplication, our library internally performs a *reduce* operation, +//! which reduces an intermediate type [`NonNativeFieldMulResultVar`] to the normalized type [`NonNativeFieldVar`]. +//! This enables a user to seamlessly perform a sequence of operations without worrying about the underlying details. +//! +//! However, this operation is expensive and is sometimes avoidable. We can reduce the number of constraints by using this intermediate type, which only supports additions. To multiply, it must be reduced back to [`NonNativeFieldVar`]. See below for a skeleton example. +//! +//! --- +//! +//! To compute `a * b + c * d`, the straightforward (but more expensive) implementation is as follows: +//! +//! ```ignore +//! let a_times_b = &a * &b; +//! let c_times_d = &c * &d; +//! let res = &a_times_b + &c_times_d; +//! ``` +//! +//! This performs two *reduce* operations in total, one for each multiplication. +//! +//! --- +//! +//! We can save one reduction by using the [`NonNativeFieldMulResultVar`], as follows: +//! +//! ```ignore +//! let a_times_b = a.mul_without_reduce(&b)?; +//! let c_times_d = c.mul_without_reduce(&d)?; +//! let res = (&a_times_b + &c_times_d)?.reduce()?; +//! ``` +//! +//! It performs only one *reduce* operation and is roughly 2x faster than the first implementation. +//! +//! ## Inspiration and basic design +//! +//! This implementation employs the standard idea of using multiple **limbs** to represent an element of the target field. For example, an element in the TargetField may be represented by three BaseField elements (i.e., the limbs). +//! +//! ```text +//! TargetField -> limb 1, limb 2, and limb 3 (each is a BaseField element) +//! ``` +//! +//! After some computation, the limbs become saturated and need to be **reduced**, in order to engage in more computation. +//! +//! We heavily use the optimization techniques in [\[KPS18\]](https://akosba.github.io/papers/xjsnark.pdf) and [\[OWWB20\]](https://eprint.iacr.org/2019/1494). +//! Both works have their own open-source libraries: +//! [xJsnark](https://github.com/akosba/xjsnark) and +//! [bellman-bignat](https://github.com/alex-ozdemir/bellman-bignat). +//! Compared with these, this module works with the `arkworks` ecosystem. +//! It also provides the option (based on an `optimization_goal` for the constraint system) to optimize +//! for constraint density instead of number of constraints, which improves efficiency in +//! proof systems like [Marlin](https://github.com/arkworks-rs/marlin). +//! +//! ## References +//! \[KPS18\]: A. E. Kosba, C. Papamanthou, and E. Shi. "xJsnark: a framework for efficient verifiable computation," in *Proceedings of the 39th Symposium on Security and Privacy*, ser. S&P ’18, 2018, pp. 944–961. +//! +//! \[OWWB20\]: A. Ozdemir, R. S. Wahby, B. Whitehat, and D. Boneh. "Scaling verifiable computation using efficient set accumulators," in *Proceedings of the 29th USENIX Security Symposium*, ser. Security ’20, 2020. +//! +//! [`NonNativeFieldVar`]: crate::fields::nonnative::NonNativeFieldVar +//! [`NonNativeFieldMulResultVar`]: crate::fields::nonnative::NonNativeFieldMulResultVar +//! [`FpVar`]: crate::fields::fp::FpVar + +#![allow( + clippy::redundant_closure_call, + clippy::enum_glob_use, + clippy::missing_errors_doc, + clippy::cast_possible_truncation, + clippy::unseparated_literal_suffix +)] + +use ark_std::fmt::Debug; + +/// Utilities for sampling parameters for non-native field gadgets +/// +/// - `BaseField`: the constraint field +/// - `TargetField`: the field being simulated +/// - `num_limbs`: how many limbs are used +/// - `bits_per_limb`: the size of the limbs +pub mod params; +/// How are non-native elements reduced? +pub(crate) mod reduce; + +/// a macro for computing ceil(log2(x)) for a field element x +macro_rules! overhead { + ($x:expr) => {{ + use ark_ff::BigInteger; + let num = $x; + let num_bits = num.into_repr().to_bits_be(); + let mut skipped_bits = 0; + for b in num_bits.iter() { + if *b == false { + skipped_bits += 1; + } else { + break; + } + } + + let mut is_power_of_2 = true; + for b in num_bits.iter().skip(skipped_bits + 1) { + if *b == true { + is_power_of_2 = false; + } + } + + if is_power_of_2 { + num_bits.len() - skipped_bits + } else { + num_bits.len() - skipped_bits + 1 + } + }}; +} + +pub(crate) use overhead; + +/// Parameters for a specific `NonNativeFieldVar` instantiation +#[derive(Clone, Debug)] +pub struct NonNativeFieldParams { + /// The number of limbs (`BaseField` elements) used to represent a `TargetField` element. Highest limb first. + pub num_limbs: usize, + + /// The number of bits of the limb + pub bits_per_limb: usize, +} + +mod allocated_field_var; +pub use allocated_field_var::*; + +mod allocated_mul_result; +pub use allocated_mul_result::*; + +mod field_var; +pub use field_var::*; + +mod mul_result; +pub use mul_result::*; diff --git a/src/fields/nonnative/mul_result.rs b/src/fields/nonnative/mul_result.rs new file mode 100644 index 0000000..b1eb58c --- /dev/null +++ b/src/fields/nonnative/mul_result.rs @@ -0,0 +1,78 @@ +use super::{AllocatedNonNativeFieldMulResultVar, NonNativeFieldVar}; +use ark_ff::PrimeField; +use ark_relations::r1cs::Result as R1CSResult; + +/// An intermediate representation especially for the result of a multiplication, containing more limbs. +/// It is intended for advanced usage to improve the efficiency. +/// +/// That is, instead of calling `mul`, one can call `mul_without_reduce` to +/// obtain this intermediate representation, which can still be added. +/// Then, one can call `reduce` to reduce it back to `NonNativeFieldVar`. +/// This may help cut the number of reduce operations. +#[derive(Debug)] +#[must_use] +pub enum NonNativeFieldMulResultVar { + /// as a constant + Constant(TargetField), + /// as an allocated gadget + Var(AllocatedNonNativeFieldMulResultVar), +} + +impl + NonNativeFieldMulResultVar +{ + /// Create a zero `NonNativeFieldMulResultVar` (used for additions) + pub fn zero() -> Self { + Self::Constant(TargetField::zero()) + } + + /// Create an `NonNativeFieldMulResultVar` from a constant + pub fn constant(v: TargetField) -> Self { + Self::Constant(v) + } + + /// Reduce the `NonNativeFieldMulResultVar` back to NonNativeFieldVar + #[tracing::instrument(target = "r1cs")] + pub fn reduce(&self) -> R1CSResult> { + match self { + Self::Constant(c) => Ok(NonNativeFieldVar::Constant(*c)), + Self::Var(v) => Ok(NonNativeFieldVar::Var(v.reduce()?)), + } + } +} + +impl + From<&NonNativeFieldVar> + for NonNativeFieldMulResultVar +{ + fn from(src: &NonNativeFieldVar) -> Self { + match src { + NonNativeFieldVar::Constant(c) => NonNativeFieldMulResultVar::Constant(*c), + NonNativeFieldVar::Var(v) => { + NonNativeFieldMulResultVar::Var(AllocatedNonNativeFieldMulResultVar::< + TargetField, + BaseField, + >::from(v)) + } + } + } +} + +impl_bounded_ops!( + NonNativeFieldMulResultVar, + TargetField, + Add, + add, + AddAssign, + add_assign, + |this: &'a NonNativeFieldMulResultVar, other: &'a NonNativeFieldMulResultVar| { + use NonNativeFieldMulResultVar::*; + match (this, other) { + (Constant(c1), Constant(c2)) => Constant(*c1 + c2), + (Constant(c), Var(v)) | (Var(v), Constant(c)) => Var(v.add_constant(c).unwrap()), + (Var(v1), Var(v2)) => Var(v1.add(v2).unwrap()), + } + }, + |this: &'a NonNativeFieldMulResultVar, other: TargetField| { this + &NonNativeFieldMulResultVar::Constant(other) }, + (TargetField: PrimeField, BaseField: PrimeField), +); diff --git a/src/fields/nonnative/params.rs b/src/fields/nonnative/params.rs new file mode 100644 index 0000000..6345547 --- /dev/null +++ b/src/fields/nonnative/params.rs @@ -0,0 +1,96 @@ +use super::NonNativeFieldParams; + +/// Obtain the parameters from a `ConstraintSystem`'s cache or generate a new one +#[must_use] +pub const fn get_params( + target_field_size: usize, + base_field_size: usize, + optimization_type: OptimizationType, +) -> NonNativeFieldParams { + let (num_of_limbs, limb_size) = + find_parameters(base_field_size, target_field_size, optimization_type); + NonNativeFieldParams { + num_limbs: num_of_limbs, + bits_per_limb: limb_size, + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +/// The type of optimization target for the parameters searching +pub enum OptimizationType { + /// Optimized for constraints + Constraints, + /// Optimized for weight + Weight, +} + +/// A function to search for parameters for nonnative field gadgets +pub const fn find_parameters( + base_field_prime_length: usize, + target_field_prime_bit_length: usize, + optimization_type: OptimizationType, +) -> (usize, usize) { + let mut found = false; + let mut min_cost = 0usize; + let mut min_cost_limb_size = 0usize; + let mut min_cost_num_of_limbs = 0usize; + + let surfeit = 10; + let mut max_limb_size = (base_field_prime_length - 1 - surfeit - 1) / 2 - 1; + if max_limb_size > target_field_prime_bit_length { + max_limb_size = target_field_prime_bit_length; + } + let mut limb_size = 1; + + while limb_size <= max_limb_size { + let num_of_limbs = (target_field_prime_bit_length + limb_size - 1) / limb_size; + + let group_size = + (base_field_prime_length - 1 - surfeit - 1 - 1 - limb_size + limb_size - 1) / limb_size; + let num_of_groups = (2 * num_of_limbs - 1 + group_size - 1) / group_size; + + let mut this_cost = 0; + + match optimization_type { + OptimizationType::Constraints => { + this_cost += 2 * num_of_limbs - 1; + } + OptimizationType::Weight => { + this_cost += 6 * num_of_limbs * num_of_limbs; + } + }; + + match optimization_type { + OptimizationType::Constraints => { + this_cost += target_field_prime_bit_length; // allocation of k + this_cost += target_field_prime_bit_length + num_of_limbs; // allocation of r + //this_cost += 2 * num_of_limbs - 1; // compute kp + this_cost += num_of_groups + (num_of_groups - 1) * (limb_size * 2 + surfeit) + 1; + // equality check + } + OptimizationType::Weight => { + this_cost += target_field_prime_bit_length * 3 + target_field_prime_bit_length; // allocation of k + this_cost += target_field_prime_bit_length * 3 + + target_field_prime_bit_length + + num_of_limbs; // allocation of r + this_cost += num_of_limbs * num_of_limbs + 2 * (2 * num_of_limbs - 1); // compute kp + this_cost += num_of_limbs + + num_of_groups + + 6 * num_of_groups + + (num_of_groups - 1) * (2 * limb_size + surfeit) * 4 + + 2; // equality check + } + }; + + if !found || this_cost < min_cost { + found = true; + min_cost = this_cost; + min_cost_limb_size = limb_size; + min_cost_num_of_limbs = num_of_limbs; + } + + limb_size += 1; + } + + (min_cost_num_of_limbs, min_cost_limb_size) +} diff --git a/src/fields/nonnative/reduce.rs b/src/fields/nonnative/reduce.rs new file mode 100644 index 0000000..c5d2089 --- /dev/null +++ b/src/fields/nonnative/reduce.rs @@ -0,0 +1,334 @@ +use super::overhead; +use super::params::get_params; +use super::AllocatedNonNativeFieldVar; +use crate::eq::EqGadget; +use crate::fields::fp::FpVar; +use crate::fields::FieldVar; +use crate::{alloc::AllocVar, boolean::Boolean, R1CSVar}; +use ark_ff::{biginteger::BigInteger, fields::FpParameters, BitIteratorBE, One, PrimeField, Zero}; +use ark_relations::{ + ns, + r1cs::{ConstraintSystemRef, Result as R1CSResult}, +}; +use ark_std::{cmp::min, marker::PhantomData, vec, vec::Vec}; +use num_bigint::BigUint; +use num_integer::Integer; + +pub fn limbs_to_bigint( + bits_per_limb: usize, + limbs: &[BaseField], +) -> BigUint { + let mut val = BigUint::zero(); + let mut big_cur = BigUint::one(); + let two = BigUint::from(2u32); + for limb in limbs.iter().rev() { + let limb_repr = limb.into_repr().to_bits_le(); + let mut small_cur = big_cur.clone(); + for limb_bit in limb_repr.iter() { + if *limb_bit { + val += &small_cur; + } + small_cur *= 2u32; + } + big_cur *= two.pow(bits_per_limb as u32); + } + + val +} + +pub fn bigint_to_basefield(bigint: &BigUint) -> BaseField { + let mut val = BaseField::zero(); + let mut cur = BaseField::one(); + let bytes = bigint.to_bytes_be(); + + let basefield_256 = BaseField::from_repr(::BigInt::from(256)).unwrap(); + + for byte in bytes.iter().rev() { + let bytes_basefield = BaseField::from(*byte as u128); + val += cur * bytes_basefield; + + cur *= &basefield_256; + } + + val +} + +/// the collections of methods for reducing the presentations +pub struct Reducer { + pub target_phantom: PhantomData, + pub base_phantom: PhantomData, +} + +impl Reducer { + /// convert limbs to bits (take at most `BaseField::size_in_bits() - 1` bits) + /// This implementation would be more efficient than the original `to_bits` + /// or `to_non_unique_bits` since we enforce that some bits are always zero. + #[tracing::instrument(target = "r1cs")] + pub fn limb_to_bits( + limb: &FpVar, + num_bits: usize, + ) -> R1CSResult>> { + let cs = limb.cs(); + + let num_bits = min(BaseField::size_in_bits() - 1, num_bits); + let mut bits_considered = Vec::with_capacity(num_bits); + let limb_value = limb.value().unwrap_or_default(); + + for b in BitIteratorBE::new(limb_value.into_repr()).skip( + <::Params as FpParameters>::REPR_SHAVE_BITS as usize + + (BaseField::size_in_bits() - num_bits), + ) { + bits_considered.push(b); + } + + if cs == ConstraintSystemRef::None { + let mut bits = vec![]; + for b in bits_considered { + bits.push(Boolean::::Constant(b)); + } + + Ok(bits) + } else { + let mut bits = vec![]; + for b in bits_considered { + bits.push(Boolean::::new_witness( + ark_relations::ns!(cs, "bit"), + || Ok(b), + )?); + } + + let mut bit_sum = FpVar::::zero(); + let mut coeff = BaseField::one(); + + for bit in bits.iter().rev() { + bit_sum += + as From>>::from((*bit).clone()) * coeff; + coeff.double_in_place(); + } + + bit_sum.enforce_equal(limb)?; + + Ok(bits) + } + } + + /// Reduction to the normal form + #[tracing::instrument(target = "r1cs")] + pub fn reduce(elem: &mut AllocatedNonNativeFieldVar) -> R1CSResult<()> { + let new_elem = + AllocatedNonNativeFieldVar::new_witness(ns!(elem.cs(), "normal_form"), || { + Ok(elem.value().unwrap_or_default()) + })?; + elem.conditional_enforce_equal(&new_elem, &Boolean::TRUE)?; + *elem = new_elem; + + Ok(()) + } + + /// Reduction to be enforced after additions + #[tracing::instrument(target = "r1cs")] + pub fn post_add_reduce( + elem: &mut AllocatedNonNativeFieldVar, + ) -> R1CSResult<()> { + let params = get_params( + TargetField::size_in_bits(), + BaseField::size_in_bits(), + elem.get_optimization_type(), + ); + let surfeit = overhead!(elem.num_of_additions_over_normal_form + BaseField::one()) + 1; + + if BaseField::size_in_bits() > 2 * params.bits_per_limb + surfeit + 1 { + Ok(()) + } else { + Self::reduce(elem) + } + } + + /// Reduction used before multiplication to reduce the representations in a way that allows efficient multiplication + #[tracing::instrument(target = "r1cs")] + pub fn pre_mul_reduce( + elem: &mut AllocatedNonNativeFieldVar, + elem_other: &mut AllocatedNonNativeFieldVar, + ) -> R1CSResult<()> { + assert_eq!( + elem.get_optimization_type(), + elem_other.get_optimization_type() + ); + + let params = get_params( + TargetField::size_in_bits(), + BaseField::size_in_bits(), + elem.get_optimization_type(), + ); + + if 2 * params.bits_per_limb + ark_std::log2(params.num_limbs) as usize + > BaseField::size_in_bits() - 1 + { + panic!("The current limb parameters do not support multiplication."); + } + + loop { + let prod_of_num_of_additions = (elem.num_of_additions_over_normal_form + + BaseField::one()) + * (elem_other.num_of_additions_over_normal_form + BaseField::one()); + let overhead_limb = overhead!(prod_of_num_of_additions.mul( + &BaseField::from_repr(::BigInt::from( + (params.num_limbs) as u64 + )) + .unwrap() + )); + let bits_per_mulresult_limb = 2 * (params.bits_per_limb + 1) + overhead_limb; + + if bits_per_mulresult_limb < BaseField::size_in_bits() { + break; + } + + if elem.num_of_additions_over_normal_form + >= elem_other.num_of_additions_over_normal_form + { + Self::reduce(elem)?; + } else { + Self::reduce(elem_other)?; + } + } + + Ok(()) + } + + /// Reduction to the normal form + #[tracing::instrument(target = "r1cs")] + pub fn pre_eq_reduce( + elem: &mut AllocatedNonNativeFieldVar, + ) -> R1CSResult<()> { + if elem.is_in_the_normal_form { + return Ok(()); + } + + Self::reduce(elem) + } + + /// Group and check equality + #[tracing::instrument(target = "r1cs")] + pub fn group_and_check_equality( + surfeit: usize, + bits_per_limb: usize, + shift_per_limb: usize, + left: &[FpVar], + right: &[FpVar], + ) -> R1CSResult<()> { + let cs = left.cs().or(right.cs()); + let zero = FpVar::::zero(); + + let mut limb_pairs = Vec::<(FpVar, FpVar)>::new(); + let num_limb_in_a_group = (BaseField::size_in_bits() + - 1 + - surfeit + - 1 + - 1 + - 1 + - (bits_per_limb - shift_per_limb)) + / shift_per_limb; + + let shift_array = { + let mut array = Vec::new(); + let mut cur = BaseField::one().into_repr(); + for _ in 0..num_limb_in_a_group { + array.push(BaseField::from_repr(cur).unwrap()); + cur.muln(shift_per_limb as u32); + } + + array + }; + + for (left_limb, right_limb) in left.iter().zip(right.iter()).rev() { + // note: the `rev` operation is here, so that the first limb (and the first groupped limb) will be the least significant limb. + limb_pairs.push((left_limb.clone(), right_limb.clone())); + } + + let mut groupped_limb_pairs = Vec::<(FpVar, FpVar, usize)>::new(); + + for limb_pairs_in_a_group in limb_pairs.chunks(num_limb_in_a_group) { + let mut left_total_limb = zero.clone(); + let mut right_total_limb = zero.clone(); + + for ((left_limb, right_limb), shift) in + limb_pairs_in_a_group.iter().zip(shift_array.iter()) + { + left_total_limb += &(left_limb * *shift); + right_total_limb += &(right_limb * *shift); + } + + groupped_limb_pairs.push(( + left_total_limb, + right_total_limb, + limb_pairs_in_a_group.len(), + )); + } + + // This part we mostly use the techniques in bellman-bignat + // The following code is adapted from https://github.com/alex-ozdemir/bellman-bignat/blob/master/src/mp/bignat.rs#L567 + let mut carry_in = zero; + let mut carry_in_value = BaseField::zero(); + let mut accumulated_extra = BigUint::zero(); + for (group_id, (left_total_limb, right_total_limb, num_limb_in_this_group)) in + groupped_limb_pairs.iter().enumerate() + { + let mut pad_limb_repr: ::BigInt = BaseField::one().into_repr(); + + pad_limb_repr.muln( + (surfeit + + (bits_per_limb - shift_per_limb) + + shift_per_limb * num_limb_in_this_group + + 1 + + 1) as u32, + ); + let pad_limb = BaseField::from_repr(pad_limb_repr).unwrap(); + + let left_total_limb_value = left_total_limb.value().unwrap_or_default(); + let right_total_limb_value = right_total_limb.value().unwrap_or_default(); + + let mut carry_value = + left_total_limb_value + carry_in_value + pad_limb - right_total_limb_value; + + let mut carry_repr = carry_value.into_repr(); + carry_repr.divn((shift_per_limb * num_limb_in_this_group) as u32); + + carry_value = BaseField::from_repr(carry_repr).unwrap(); + + let carry = FpVar::::new_witness(cs.clone(), || Ok(carry_value))?; + + accumulated_extra += limbs_to_bigint(bits_per_limb, &[pad_limb]); + + let (new_accumulated_extra, remainder) = accumulated_extra.div_rem( + &BigUint::from(2u64).pow((shift_per_limb * num_limb_in_this_group) as u32), + ); + let remainder_limb = bigint_to_basefield::(&remainder); + + // Now check + // left_total_limb + pad_limb + carry_in - right_total_limb + // = carry shift by (shift_per_limb * num_limb_in_this_group) + remainder + + let eqn_left = left_total_limb + pad_limb + &carry_in - right_total_limb; + + let eqn_right = &carry + * BaseField::from(2u64).pow(&[(shift_per_limb * num_limb_in_this_group) as u64]) + + remainder_limb; + + eqn_left.conditional_enforce_equal(&eqn_right, &Boolean::::TRUE)?; + + accumulated_extra = new_accumulated_extra; + carry_in = carry.clone(); + carry_in_value = carry_value; + + if group_id == groupped_limb_pairs.len() - 1 { + carry.enforce_equal(&FpVar::::Constant(bigint_to_basefield( + &accumulated_extra, + )))?; + } else { + Reducer::::limb_to_bits(&carry, surfeit + bits_per_limb)?; + } + } + + Ok(()) + } +} diff --git a/tests/arithmetic_tests.rs b/tests/arithmetic_tests.rs new file mode 100644 index 0000000..9fa75ad --- /dev/null +++ b/tests/arithmetic_tests.rs @@ -0,0 +1,713 @@ +use ark_bls12_381::Bls12_381; +use ark_ec::PairingEngine; +use ark_ff::{BigInteger, PrimeField}; +use ark_mnt4_298::MNT4_298; +use ark_mnt4_753::MNT4_753; +use ark_mnt6_298::MNT6_298; +use ark_mnt6_753::MNT6_753; + +use ark_r1cs_std::fields::nonnative::{AllocatedNonNativeFieldVar, NonNativeFieldVar}; +use ark_r1cs_std::{alloc::AllocVar, eq::EqGadget, fields::FieldVar, R1CSVar}; +use ark_relations::r1cs::{ConstraintSystem, ConstraintSystemRef}; +use ark_std::rand::RngCore; + +#[cfg(not(ci))] +const NUM_REPETITIONS: usize = 100; +#[cfg(ci)] +const NUM_REPETITIONS: usize = 1; + +#[cfg(not(ci))] +const TEST_COUNT: usize = 100; +#[cfg(ci)] +const TEST_COUNT: usize = 1; + +fn allocation_test( + cs: ConstraintSystemRef, + rng: &mut R, +) { + let a_native = TargetField::rand(rng); + let a = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "alloc a"), + || Ok(a_native), + ) + .unwrap(); + + let a_actual = a.value().unwrap(); + let a_expected = a_native; + assert!( + a_actual.eq(&a_expected), + "allocated value does not equal the expected value" + ); + + let (_a, a_bits) = + AllocatedNonNativeFieldVar::::new_witness_with_le_bits( + ark_relations::ns!(cs, "alloc a2"), + || Ok(a_native), + ) + .unwrap(); + + let a_bits_actual: Vec = a_bits.into_iter().map(|b| b.value().unwrap()).collect(); + let mut a_bits_expected = a_native.into_repr().to_bits_le(); + a_bits_expected.truncate(TargetField::size_in_bits()); + assert_eq!( + a_bits_actual, a_bits_expected, + "allocated bits does not equal the expected bits" + ); +} + +fn addition_test( + cs: ConstraintSystemRef, + rng: &mut R, +) { + let a_native = TargetField::rand(rng); + let a = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "alloc a"), + || Ok(a_native), + ) + .unwrap(); + + let b_native = TargetField::rand(rng); + let b = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "alloc b"), + || Ok(b_native), + ) + .unwrap(); + + let a_plus_b = a + &b; + + let a_plus_b_actual = a_plus_b.value().unwrap(); + let a_plus_b_expected = a_native + &b_native; + assert!(a_plus_b_actual.eq(&a_plus_b_expected), "a + b failed"); +} + +fn multiplication_test( + cs: ConstraintSystemRef, + rng: &mut R, +) { + let a_native = TargetField::rand(rng); + let a = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "alloc a"), + || Ok(a_native), + ) + .unwrap(); + + let b_native = TargetField::rand(rng); + let b = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "alloc b"), + || Ok(b_native), + ) + .unwrap(); + + let a_times_b = a * &b; + + let a_times_b_actual = a_times_b.value().unwrap(); + let a_times_b_expected = a_native * &b_native; + + assert!( + a_times_b_actual.eq(&a_times_b_expected), + "a_times_b = {:?}, a_times_b_actual = {:?}, a_times_b_expected = {:?}", + a_times_b, + a_times_b_actual.into_repr().as_ref(), + a_times_b_expected.into_repr().as_ref() + ); +} + +fn equality_test( + cs: ConstraintSystemRef, + rng: &mut R, +) { + let a_native = TargetField::rand(rng); + let a = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "alloc a"), + || Ok(a_native), + ) + .unwrap(); + + let b_native = TargetField::rand(rng); + let b = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "alloc b"), + || Ok(b_native), + ) + .unwrap(); + + let a_times_b = a * &b; + + let a_times_b_expected = a_native * &b_native; + let a_times_b_expected_gadget = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "alloc a * b"), + || Ok(a_times_b_expected), + ) + .unwrap(); + + a_times_b.enforce_equal(&a_times_b_expected_gadget).unwrap(); +} + +fn edge_cases_test( + cs: ConstraintSystemRef, + rng: &mut R, +) { + let zero_native = TargetField::zero(); + let zero = NonNativeFieldVar::::zero(); + let one = NonNativeFieldVar::::one(); + + let a_native = TargetField::rand(rng); + let minus_a_native = TargetField::zero() - &a_native; + let a = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "alloc a"), + || Ok(a_native), + ) + .unwrap(); + + let a_plus_zero = &a + &zero; + let a_minus_zero = &a - &zero; + let zero_minus_a = &zero - &a; + let a_times_zero = &a * &zero; + + let zero_plus_a = &zero + &a; + let zero_times_a = &zero * &a; + + let a_times_one = &a * &one; + let one_times_a = &one * &a; + + let a_plus_zero_native = a_plus_zero.value().unwrap(); + let a_minus_zero_native = a_minus_zero.value().unwrap(); + let zero_minus_a_native = zero_minus_a.value().unwrap(); + let a_times_zero_native = a_times_zero.value().unwrap(); + let zero_plus_a_native = zero_plus_a.value().unwrap(); + let zero_times_a_native = zero_times_a.value().unwrap(); + let a_times_one_native = a_times_one.value().unwrap(); + let one_times_a_native = one_times_a.value().unwrap(); + + assert!( + a_plus_zero_native.eq(&a_native), + "a_plus_zero = {:?}, a = {:?}", + a_plus_zero_native.into_repr().as_ref(), + a_native.into_repr().as_ref() + ); + assert!( + a_minus_zero_native.eq(&a_native), + "a_minus_zero = {:?}, a = {:?}", + a_minus_zero_native.into_repr().as_ref(), + a_native.into_repr().as_ref() + ); + assert!( + zero_minus_a_native.eq(&minus_a_native), + "zero_minus_a = {:?}, minus_a = {:?}", + zero_minus_a_native.into_repr().as_ref(), + minus_a_native.into_repr().as_ref() + ); + assert!( + a_times_zero_native.eq(&zero_native), + "a_times_zero = {:?}, zero = {:?}", + a_times_zero_native.into_repr().as_ref(), + zero_native.into_repr().as_ref() + ); + assert!( + zero_plus_a_native.eq(&a_native), + "zero_plus_a = {:?}, a = {:?}", + zero_plus_a_native.into_repr().as_ref(), + a_native.into_repr().as_ref() + ); + assert!( + zero_times_a_native.eq(&zero_native), + "zero_times_a = {:?}, zero = {:?}", + zero_times_a_native.into_repr().as_ref(), + zero_native.into_repr().as_ref() + ); + assert!( + a_times_one_native.eq(&a_native), + "a_times_one = {:?}, a = {:?}", + a_times_one_native.into_repr().as_ref(), + a_native.into_repr().as_ref() + ); + assert!( + one_times_a_native.eq(&a_native), + "one_times_a = {:?}, a = {:?}", + one_times_a_native.into_repr().as_ref(), + a_native.into_repr().as_ref() + ); +} + +fn distribution_law_test( + cs: ConstraintSystemRef, + rng: &mut R, +) { + let a_native = TargetField::rand(rng); + let b_native = TargetField::rand(rng); + let c_native = TargetField::rand(rng); + + let a_plus_b_native = a_native.clone() + &b_native; + let a_times_c_native = a_native.clone() * &c_native; + let b_times_c_native = b_native.clone() * &c_native; + let a_plus_b_times_c_native = a_plus_b_native.clone() * &c_native; + let a_times_c_plus_b_times_c_native = a_times_c_native + &b_times_c_native; + + assert!( + a_plus_b_times_c_native.eq(&a_times_c_plus_b_times_c_native), + "(a + b) * c doesn't equal (a * c) + (b * c)" + ); + + let a = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "a"), + || Ok(a_native), + ) + .unwrap(); + let b = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "b"), + || Ok(b_native), + ) + .unwrap(); + let c = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "c"), + || Ok(c_native), + ) + .unwrap(); + + let a_plus_b = &a + &b; + let a_times_c = &a * &c; + let b_times_c = &b * &c; + let a_plus_b_times_c = &a_plus_b * &c; + let a_times_c_plus_b_times_c = &a_times_c + &b_times_c; + + assert!( + a_plus_b.value().unwrap().eq(&a_plus_b_native), + "a + b doesn't match" + ); + assert!( + a_times_c.value().unwrap().eq(&a_times_c_native), + "a * c doesn't match" + ); + assert!( + b_times_c.value().unwrap().eq(&b_times_c_native), + "b * c doesn't match" + ); + assert!( + a_plus_b_times_c + .value() + .unwrap() + .eq(&a_plus_b_times_c_native), + "(a + b) * c doesn't match" + ); + assert!( + a_times_c_plus_b_times_c + .value() + .unwrap() + .eq(&a_times_c_plus_b_times_c_native), + "(a * c) + (b * c) doesn't match" + ); + assert!( + a_plus_b_times_c_native.eq(&a_times_c_plus_b_times_c_native), + "(a + b) * c != (a * c) + (b * c)" + ); +} + +fn randomized_arithmetic_test( + cs: ConstraintSystemRef, + rng: &mut R, +) { + let mut operations: Vec = Vec::new(); + for _ in 0..TEST_COUNT { + operations.push(rng.next_u32() % 3); + } + + let mut num_native = TargetField::rand(rng); + let mut num = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "initial num"), + || Ok(num_native), + ) + .unwrap(); + for op in operations.iter() { + let next_native = TargetField::rand(rng); + let next = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "next num for repetition"), + || Ok(next_native), + ) + .unwrap(); + match op { + 0 => { + num_native += &next_native; + num += &next; + } + 1 => { + num_native *= &next_native; + num *= &next; + } + 2 => { + num_native -= &next_native; + num -= &next; + } + _ => (), + }; + + assert!( + num.value().unwrap().eq(&num_native), + "randomized arithmetic failed:" + ); + } +} + +fn addition_stress_test( + cs: ConstraintSystemRef, + rng: &mut R, +) { + let mut num_native = TargetField::rand(rng); + let mut num = + NonNativeFieldVar::new_witness(ark_relations::ns!(cs, "initial num"), || Ok(num_native)) + .unwrap(); + for _ in 0..TEST_COUNT { + let next_native = TargetField::rand(rng); + let next = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "next num for repetition"), + || Ok(next_native), + ) + .unwrap(); + num_native += &next_native; + num += &next; + + assert!(num.value().unwrap().eq(&num_native)); + } +} + +fn multiplication_stress_test( + cs: ConstraintSystemRef, + rng: &mut R, +) { + let mut num_native = TargetField::rand(rng); + let mut num = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "initial num"), + || Ok(num_native), + ) + .unwrap(); + for _ in 0..TEST_COUNT { + let next_native = TargetField::rand(rng); + let next = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "next num for repetition"), + || Ok(next_native), + ) + .unwrap(); + num_native *= &next_native; + num *= &next; + + assert!(num.value().unwrap().eq(&num_native)); + } +} + +fn mul_and_add_stress_test( + cs: ConstraintSystemRef, + rng: &mut R, +) { + let mut num_native = TargetField::rand(rng); + let mut num = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "initial num"), + || Ok(num_native), + ) + .unwrap(); + for _ in 0..TEST_COUNT { + let next_add_native = TargetField::rand(rng); + let next_add = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "next to add num for repetition"), + || Ok(next_add_native), + ) + .unwrap(); + let next_mul_native = TargetField::rand(rng); + let next_mul = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "next to mul num for repetition"), + || Ok(next_mul_native), + ) + .unwrap(); + + num_native = num_native * &next_mul_native + &next_add_native; + num = num * &next_mul + &next_add; + + assert!(num.value().unwrap().eq(&num_native)); + } +} + +fn square_mul_add_stress_test( + cs: ConstraintSystemRef, + rng: &mut R, +) { + let mut num_native = TargetField::rand(rng); + let mut num = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "initial num"), + || Ok(num_native), + ) + .unwrap(); + for _ in 0..TEST_COUNT { + let next_add_native = TargetField::rand(rng); + let next_add = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "next to add num for repetition"), + || Ok(next_add_native), + ) + .unwrap(); + let next_mul_native = TargetField::rand(rng); + let next_mul = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "next to mul num for repetition"), + || Ok(next_mul_native), + ) + .unwrap(); + + num_native = num_native * &num_native * &next_mul_native + &next_add_native; + num = &num * &num * &next_mul + &next_add; + + assert!(num.value().unwrap().eq(&num_native)); + } +} + +fn double_stress_test_1( + cs: ConstraintSystemRef, + rng: &mut R, +) { + let mut num_native = TargetField::rand(rng); + let mut num = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "initial num"), + || Ok(num_native), + ) + .unwrap(); + // Add to at least BaseField::size_in_bits() to ensure that we teat the overflowing + for _ in 0..TEST_COUNT + BaseField::size_in_bits() { + // double + num_native = num_native + &num_native; + num = &num + # + + assert!(num.value().unwrap().eq(&num_native), "result incorrect"); + } +} + +fn double_stress_test_2( + cs: ConstraintSystemRef, + rng: &mut R, +) { + let mut num_native = TargetField::rand(rng); + let mut num = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "initial num"), + || Ok(num_native), + ) + .unwrap(); + for _ in 0..TEST_COUNT { + // double + num_native = num_native + &num_native; + num = &num + # + + assert!(num.value().unwrap().eq(&num_native)); + + // square + let num_square_native = num_native * &num_native; + let num_square = &num * # + assert!(num_square.value().unwrap().eq(&num_square_native)); + } +} + +fn double_stress_test_3( + cs: ConstraintSystemRef, + rng: &mut R, +) { + let mut num_native = TargetField::rand(rng); + let mut num = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "initial num"), + || Ok(num_native), + ) + .unwrap(); + for _ in 0..TEST_COUNT { + // double + num_native = num_native + &num_native; + num = &num + # + + assert!(num.value().unwrap().eq(&num_native)); + + // square + let num_square_native = num_native * &num_native; + let num_square = &num * # + let num_square_native_gadget = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "repetition: alloc_native num"), + || Ok(num_square_native), + ) + .unwrap(); + + num_square.enforce_equal(&num_square_native_gadget).unwrap(); + } +} + +fn inverse_stress_test( + cs: ConstraintSystemRef, + rng: &mut R, +) { + for _ in 0..TEST_COUNT { + let num_native = TargetField::rand(rng); + let num = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "num"), + || Ok(num_native), + ) + .unwrap(); + + if num_native == TargetField::zero() { + continue; + } + + let num_native_inverse = num_native.inverse().unwrap(); + let num_inverse = num.inverse().unwrap(); + + assert!(num_inverse.value().unwrap().eq(&num_native_inverse)); + } +} + +macro_rules! nonnative_test_individual { + ($test_method:ident, $test_name:ident, $test_target_field:ty, $test_base_field:ty) => { + paste::item! { + #[test] + fn [<$test_method _ $test_name:lower>]() { + let rng = &mut ark_std::test_rng(); + + for _ in 0..NUM_REPETITIONS { + let cs = ConstraintSystem::<$test_base_field>::new_ref(); + $test_method::<$test_target_field, $test_base_field, _>(cs.clone(), rng); + assert!(cs.is_satisfied().unwrap()); + } + } + } + }; +} + +macro_rules! nonnative_test { + ($test_name:ident, $test_target_field:ty, $test_base_field:ty) => { + nonnative_test_individual!( + allocation_test, + $test_name, + $test_target_field, + $test_base_field + ); + nonnative_test_individual!( + addition_test, + $test_name, + $test_target_field, + $test_base_field + ); + nonnative_test_individual!( + multiplication_test, + $test_name, + $test_target_field, + $test_base_field + ); + nonnative_test_individual!( + equality_test, + $test_name, + $test_target_field, + $test_base_field + ); + nonnative_test_individual!( + edge_cases_test, + $test_name, + $test_target_field, + $test_base_field + ); + nonnative_test_individual!( + distribution_law_test, + $test_name, + $test_target_field, + $test_base_field + ); + nonnative_test_individual!( + addition_stress_test, + $test_name, + $test_target_field, + $test_base_field + ); + nonnative_test_individual!( + double_stress_test_1, + $test_name, + $test_target_field, + $test_base_field + ); + nonnative_test_individual!( + double_stress_test_2, + $test_name, + $test_target_field, + $test_base_field + ); + nonnative_test_individual!( + double_stress_test_3, + $test_name, + $test_target_field, + $test_base_field + ); + nonnative_test_individual!( + randomized_arithmetic_test, + $test_name, + $test_target_field, + $test_base_field + ); + nonnative_test_individual!( + multiplication_stress_test, + $test_name, + $test_target_field, + $test_base_field + ); + nonnative_test_individual!( + mul_and_add_stress_test, + $test_name, + $test_target_field, + $test_base_field + ); + nonnative_test_individual!( + square_mul_add_stress_test, + $test_name, + $test_target_field, + $test_base_field + ); + nonnative_test_individual!( + inverse_stress_test, + $test_name, + $test_target_field, + $test_base_field + ); + }; +} + +nonnative_test!( + MNT46Small, + ::Fr, + ::Fr +); +nonnative_test!( + MNT64Small, + ::Fr, + ::Fr +); +nonnative_test!( + MNT46Big, + ::Fr, + ::Fr +); +nonnative_test!( + MNT64Big, + ::Fr, + ::Fr +); +nonnative_test!( + BLS12MNT4Small, + ::Fr, + ::Fr +); +nonnative_test!( + BLS12, + ::Fq, + ::Fr +); +#[cfg(not(ci))] +nonnative_test!( + MNT6BigMNT4Small, + ::Fr, + ::Fr +); +nonnative_test!( + PallasFrMNT6Fr, + ark_pallas::Fr, + ::Fr +); +nonnative_test!( + MNT6FrPallasFr, + ::Fr, + ark_pallas::Fr +); +nonnative_test!(PallasFqFr, ark_pallas::Fq, ark_pallas::Fr); +nonnative_test!(PallasFrFq, ark_pallas::Fr, ark_pallas::Fq); diff --git a/tests/from_test.rs b/tests/from_test.rs new file mode 100644 index 0000000..968d7ba --- /dev/null +++ b/tests/from_test.rs @@ -0,0 +1,24 @@ +use ark_r1cs_std::alloc::AllocVar; +use ark_r1cs_std::fields::nonnative::{NonNativeFieldMulResultVar, NonNativeFieldVar}; +use ark_r1cs_std::R1CSVar; +use ark_relations::r1cs::ConstraintSystem; +use ark_std::UniformRand; + +#[test] +fn from_test() { + type F = ark_bls12_377::Fr; + type CF = ark_bls12_377::Fq; + + let mut rng = ark_std::test_rng(); + let cs = ConstraintSystem::::new_ref(); + let f = F::rand(&mut rng); + + let f_var = NonNativeFieldVar::::new_input(cs.clone(), || Ok(f)).unwrap(); + let f_var_converted = NonNativeFieldMulResultVar::::from(&f_var); + let f_var_converted_reduced = f_var_converted.reduce().unwrap(); + + let f_var_value = f_var.value().unwrap(); + let f_var_converted_reduced_value = f_var_converted_reduced.value().unwrap(); + + assert_eq!(f_var_value, f_var_converted_reduced_value); +} diff --git a/tests/to_bytes_test.rs b/tests/to_bytes_test.rs new file mode 100644 index 0000000..e352c05 --- /dev/null +++ b/tests/to_bytes_test.rs @@ -0,0 +1,50 @@ +use ark_ec::PairingEngine; +use ark_ff::{to_bytes, Zero}; +use ark_mnt4_298::MNT4_298; +use ark_mnt6_298::MNT6_298; +use ark_r1cs_std::alloc::AllocVar; +use ark_r1cs_std::fields::nonnative::NonNativeFieldVar; +use ark_r1cs_std::{R1CSVar, ToBitsGadget, ToBytesGadget}; +use ark_relations::r1cs::ConstraintSystem; + +#[test] +fn to_bytes_test() { + let cs = ConstraintSystem::<::Fr>::new_ref(); + + let target_test_elem = ::Fr::from(123456u128); + let target_test_gadget = NonNativeFieldVar::< + ::Fr, + ::Fr, + >::new_witness(cs, || Ok(target_test_elem)) + .unwrap(); + + let target_to_bytes: Vec = target_test_gadget + .to_bytes() + .unwrap() + .iter() + .map(|v| v.value().unwrap()) + .collect(); + + // 123456 = 65536 + 226 * 256 + 64 + assert_eq!(target_to_bytes[0], 64); + assert_eq!(target_to_bytes[1], 226); + assert_eq!(target_to_bytes[2], 1); + + for byte in target_to_bytes.iter().skip(3) { + assert_eq!(*byte, 0); + } + + assert_eq!(to_bytes!(target_test_elem).unwrap(), target_to_bytes); +} + +#[test] +fn to_bits_test() { + type F = ark_bls12_377::Fr; + type CF = ark_bls12_377::Fq; + + let cs = ConstraintSystem::::new_ref(); + let f = F::zero(); + + let f_var = NonNativeFieldVar::::new_input(cs.clone(), || Ok(f)).unwrap(); + f_var.to_bits_le().unwrap(); +} diff --git a/tests/to_constraint_field_test.rs b/tests/to_constraint_field_test.rs new file mode 100644 index 0000000..7dc17de --- /dev/null +++ b/tests/to_constraint_field_test.rs @@ -0,0 +1,28 @@ +use ark_r1cs_std::alloc::AllocVar; +use ark_r1cs_std::fields::nonnative::NonNativeFieldVar; +use ark_r1cs_std::{R1CSVar, ToConstraintFieldGadget}; +use ark_relations::r1cs::ConstraintSystem; + +#[test] +fn to_constraint_field_test() { + type F = ark_bls12_377::Fr; + type CF = ark_bls12_377::Fq; + + let cs = ConstraintSystem::::new_ref(); + + let a = NonNativeFieldVar::Constant(F::from(12u8)); + let b = NonNativeFieldVar::new_input(cs.clone(), || Ok(F::from(6u8))).unwrap(); + + let b2 = &b + &b; + + let a_to_constraint_field = a.to_constraint_field().unwrap(); + let b2_to_constraint_field = b2.to_constraint_field().unwrap(); + + assert_eq!(a_to_constraint_field.len(), b2_to_constraint_field.len()); + for (left, right) in a_to_constraint_field + .iter() + .zip(b2_to_constraint_field.iter()) + { + assert_eq!(left.value(), right.value()); + } +}