diff --git a/crypto-primitives/src/merkle_tree/constraints.rs b/crypto-primitives/src/merkle_tree/constraints.rs index 873d3c6..0c5bdd3 100644 --- a/crypto-primitives/src/merkle_tree/constraints.rs +++ b/crypto-primitives/src/merkle_tree/constraints.rs @@ -37,7 +37,7 @@ where // proof. let leaf_bits = leaf.to_bytes()?; let leaf_hash = CRHGadget::evaluate(parameters, &leaf_bits)?; - let cs = leaf_hash.cs().or(root.cs()).unwrap(); + let cs = leaf_hash.cs().or(root.cs()); // Check if leaf is one of the bottom-most siblings. let leaf_is_left = Boolean::new_witness(r1cs_core::ns!(cs, "leaf_is_left"), || { diff --git a/crypto-primitives/src/prf/blake2s/constraints.rs b/crypto-primitives/src/prf/blake2s/constraints.rs index 4a4cde9..ce07bd2 100644 --- a/crypto-primitives/src/prf/blake2s/constraints.rs +++ b/crypto-primitives/src/prf/blake2s/constraints.rs @@ -350,7 +350,7 @@ impl AllocVar<[u8; 32], ConstraintF> for OutputVar R1CSVar for OutputVar { type Value = [u8; 32]; - fn cs(&self) -> Option> { + fn cs(&self) -> ConstraintSystemRef { self.0.cs() } diff --git a/r1cs-std/src/bits/boolean.rs b/r1cs-std/src/bits/boolean.rs index 995243a..f78762e 100644 --- a/r1cs-std/src/bits/boolean.rs +++ b/r1cs-std/src/bits/boolean.rs @@ -232,10 +232,10 @@ pub enum Boolean { impl R1CSVar for Boolean { type Value = bool; - fn cs(&self) -> Option> { + fn cs(&self) -> ConstraintSystemRef { match self { - Self::Is(a) | Self::Not(a) => Some(a.cs.clone()), - _ => None, + Self::Is(a) | Self::Not(a) => a.cs.clone(), + _ => ConstraintSystemRef::None, } } @@ -598,11 +598,10 @@ impl Boolean { match r { Constant(true) => Ok(()), Constant(false) => Err(SynthesisError::AssignmentMissing), - Is(_) | Not(_) => r.cs().unwrap().enforce_constraint( - r.lc(), - lc!() + Variable::One, - lc!() + Variable::One, - ), + Is(_) | Not(_) => { + r.cs() + .enforce_constraint(r.lc(), lc!() + Variable::One, lc!() + Variable::One) + } } } @@ -778,7 +777,7 @@ impl EqGadget for Boolean { }; if condition != &Constant(false) { - let cs = self.cs().or(other.cs()).or(condition.cs()).unwrap(); + let cs = self.cs().or(other.cs()).or(condition.cs()); cs.enforce_constraint(lc!() + difference, condition.lc(), lc!())?; } Ok(()) @@ -814,11 +813,7 @@ impl EqGadget for Boolean { }; if should_enforce != &Constant(false) { - let cs = self - .cs() - .or(other.cs()) - .or(should_enforce.cs()) - .ok_or(SynthesisError::UnconstrainedVariable)?; + let cs = self.cs().or(other.cs()).or(should_enforce.cs()); cs.enforce_constraint(difference, should_enforce.lc(), should_enforce.lc())?; } Ok(()) @@ -863,7 +858,7 @@ impl CondSelectGadget for Boolean { (&Constant(true), x) => cond.or(x), (x, &Constant(true)) => cond.not().or(x), (a, b) => { - let cs = cond.cs().unwrap(); + let cs = cond.cs(); let result: Boolean = AllocatedBit::new_witness_without_booleanity_check(cs.clone(), || { let cond = cond.value()?; diff --git a/r1cs-std/src/bits/uint.rs b/r1cs-std/src/bits/uint.rs index 920f235..df0c06a 100644 --- a/r1cs-std/src/bits/uint.rs +++ b/r1cs-std/src/bits/uint.rs @@ -38,7 +38,7 @@ macro_rules! make_uint { impl R1CSVar for $name { type Value = $native; - fn cs(&self) -> Option> { + fn cs(&self) -> ConstraintSystemRef { self.bits.as_slice().cs() } @@ -254,7 +254,7 @@ macro_rules! make_uint { return Ok($name::constant(modular_value.unwrap())); } - let cs = operands.cs().unwrap(); + let cs = operands.cs(); // Storage area for the resulting bits let mut result_bits = vec![]; diff --git a/r1cs-std/src/bits/uint8.rs b/r1cs-std/src/bits/uint8.rs index 7b2676e..6facd9b 100644 --- a/r1cs-std/src/bits/uint8.rs +++ b/r1cs-std/src/bits/uint8.rs @@ -18,7 +18,7 @@ pub struct UInt8 { impl R1CSVar for UInt8 { type Value = u8; - fn cs(&self) -> Option> { + fn cs(&self) -> ConstraintSystemRef { self.bits.as_slice().cs() } diff --git a/r1cs-std/src/eq.rs b/r1cs-std/src/eq.rs index 8c6db70..80798ae 100644 --- a/r1cs-std/src/eq.rs +++ b/r1cs-std/src/eq.rs @@ -108,16 +108,16 @@ impl + R1CSVar, F: Field> EqGadget for [T] { ) -> Result<(), SynthesisError> { assert_eq!(self.len(), other.len()); let some_are_different = self.is_neq(other)?; - if let Some(cs) = some_are_different.cs().or(should_enforce.cs()) { + if [&some_are_different, should_enforce].is_constant() { + assert!(some_are_different.value().unwrap()); + Ok(()) + } else { + let cs = [&some_are_different, should_enforce].cs(); cs.enforce_constraint( some_are_different.lc(), should_enforce.lc(), should_enforce.lc(), ) - } else { - // `some_are_different` and `should_enforce` are both constants - assert!(some_are_different.value().unwrap()); - Ok(()) } } } diff --git a/r1cs-std/src/fields/cubic_extension.rs b/r1cs-std/src/fields/cubic_extension.rs index 8c2ef13..0a96bf2 100644 --- a/r1cs-std/src/fields/cubic_extension.rs +++ b/r1cs-std/src/fields/cubic_extension.rs @@ -9,7 +9,7 @@ use crate::fields::fp::FpVar; use crate::{ fields::{FieldOpsBounds, FieldVar}, prelude::*, - Assignment, ToConstraintFieldGadget, Vec, + ToConstraintFieldGadget, Vec, }; /// This struct is the `R1CS` equivalent of the cubic extension field type @@ -89,7 +89,7 @@ where { type Value = CubicExtField

; - fn cs(&self) -> Option> { + fn cs(&self) -> ConstraintSystemRef { [&self.c0, &self.c1, &self.c2].cs() } @@ -272,7 +272,7 @@ where AllocationMode::Witness }; let inverse = Self::new_variable( - self.cs().get()?.clone(), + self.cs(), || { self.value() .map(|f| f.inverse().unwrap_or(CubicExtField::zero())) diff --git a/r1cs-std/src/fields/fp/cmp.rs b/r1cs-std/src/fields/fp/cmp.rs index 5c61a11..e346f9b 100644 --- a/r1cs-std/src/fields/fp/cmp.rs +++ b/r1cs-std/src/fields/fp/cmp.rs @@ -140,10 +140,11 @@ impl FpVar { /// Helper function to enforce `self < other`. This function assumes `self` and `other` /// are `<= (p-1)/2` and does not generate constraints to verify that. fn enforce_smaller_than_unchecked(&self, other: &FpVar) -> Result<(), SynthesisError> { - let cs = [self, other].cs().unwrap(); let is_smaller_than = self.is_smaller_than_unchecked(other)?; let lc_one = lc!() + Variable::One; - cs.enforce_constraint(is_smaller_than.lc(), lc_one.clone(), lc_one) + [self, other] + .cs() + .enforce_constraint(is_smaller_than.lc(), lc_one.clone(), lc_one) } } diff --git a/r1cs-std/src/fields/fp/mod.rs b/r1cs-std/src/fields/fp/mod.rs index 3614ead..1fa2266 100644 --- a/r1cs-std/src/fields/fp/mod.rs +++ b/r1cs-std/src/fields/fp/mod.rs @@ -46,10 +46,10 @@ pub enum FpVar { impl R1CSVar for FpVar { type Value = F; - fn cs(&self) -> Option> { + fn cs(&self) -> ConstraintSystemRef { match self { - Self::Constant(_) => Some(ConstraintSystemRef::None), - Self::Var(a) => Some(a.cs.clone()), + Self::Constant(_) => ConstraintSystemRef::None, + Self::Var(a) => a.cs.clone(), } } @@ -67,7 +67,7 @@ impl From> for FpVar { Self::Constant(F::from(b as u8)) } else { // `other` is a variable - let cs = other.cs().unwrap(); + let cs = other.cs(); let variable = cs.new_lc(other.lc()).unwrap(); Self::Var(AllocatedFp::new( other.value().ok().map(|b| F::from(b as u8)), @@ -90,12 +90,9 @@ impl<'a, F: PrimeField> FieldOpsBounds<'a, F, FpVar> for &'a FpVar {} impl AllocatedFp { /// Constructs `Self` from a `Boolean`: if `other` is false, this outputs `zero`, else it outputs `one`. pub fn from(other: Boolean) -> Self { - if let Some(cs) = other.cs() { - let variable = cs.new_lc(other.lc()).unwrap(); - Self::new(other.value().ok().map(|b| F::from(b as u8)), variable, cs) - } else { - unreachable!("Cannot create a constant value") - } + let cs = other.cs(); + let variable = cs.new_lc(other.lc()).unwrap(); + Self::new(other.value().ok().map(|b| F::from(b as u8)), variable, cs) } /// Returns the value assigned to `self` in the underlying constraint system @@ -511,7 +508,7 @@ impl CondSelectGadget for AllocatedFp { Boolean::Constant(true) => Ok(true_val.clone()), Boolean::Constant(false) => Ok(false_val.clone()), _ => { - let cs = cond.cs().unwrap(); + let cs = cond.cs(); let result = Self::new_witness(cs.clone(), || { cond.value() .and_then(|c| if c { true_val } else { false_val }.value.get()) @@ -541,24 +538,20 @@ impl TwoBitLookupGadget for AllocatedFp { fn two_bit_lookup(b: &[Boolean], c: &[Self::TableConstant]) -> Result { debug_assert_eq!(b.len(), 2); debug_assert_eq!(c.len(), 4); - if let Some(cs) = b.cs() { - let result = Self::new_witness(cs.clone(), || { - let lsb = usize::from(b[0].value()?); - let msb = usize::from(b[1].value()?); - let index = lsb + (msb << 1); - Ok(c[index]) - })?; - let one = Variable::One; - cs.enforce_constraint( - lc!() + b[1].lc() * (c[3] - &c[2] - &c[1] + &c[0]) + (c[1] - &c[0], one), - lc!() + b[0].lc(), - lc!() + result.variable - (c[0], one) + b[1].lc() * (c[0] - &c[2]), - )?; - - Ok(result) - } else { - unreachable!("must provide a way to obtain a ConstraintSystemRef") - } + let result = Self::new_witness(b.cs(), || { + let lsb = usize::from(b[0].value()?); + let msb = usize::from(b[1].value()?); + let index = lsb + (msb << 1); + Ok(c[index]) + })?; + let one = Variable::One; + b.cs().enforce_constraint( + lc!() + b[1].lc() * (c[3] - &c[2] - &c[1] + &c[0]) + (c[1] - &c[0], one), + lc!() + b[0].lc(), + lc!() + result.variable - (c[0], one) + b[1].lc() * (c[0] - &c[2]), + )?; + + Ok(result) } } @@ -573,37 +566,32 @@ impl ThreeBitCondNegLookupGadget for AllocatedFp { ) -> Result { debug_assert_eq!(b.len(), 3); debug_assert_eq!(c.len(), 4); + let result = Self::new_witness(b.cs(), || { + let lsb = usize::from(b[0].value()?); + let msb = usize::from(b[1].value()?); + let index = lsb + (msb << 1); + let intermediate = c[index]; - if let Some(cs) = b.cs() { - let result = Self::new_witness(cs.clone(), || { - let lsb = usize::from(b[0].value()?); - let msb = usize::from(b[1].value()?); - let index = lsb + (msb << 1); - let intermediate = c[index]; - - let is_negative = b[2].value()?; - let y = if is_negative { - -intermediate - } else { - intermediate - }; - Ok(y) - })?; - - let y_lc = b0b1.lc() * (c[3] - &c[2] - &c[1] + &c[0]) - + b[0].lc() * (c[1] - &c[0]) - + b[1].lc() * (c[2] - &c[0]) - + (c[0], Variable::One); - cs.enforce_constraint( - y_lc.clone() + y_lc.clone(), - b[2].lc(), - y_lc.clone() - result.variable, - )?; - - Ok(result) - } else { - unreachable!("must provide a way to obtain a ConstraintSystemRef") - } + let is_negative = b[2].value()?; + let y = if is_negative { + -intermediate + } else { + intermediate + }; + Ok(y) + })?; + + let y_lc = b0b1.lc() * (c[3] - &c[2] - &c[1] + &c[0]) + + b[0].lc() * (c[1] - &c[0]) + + b[1].lc() * (c[2] - &c[0]) + + (c[0], Variable::One); + b.cs().enforce_constraint( + y_lc.clone() + y_lc.clone(), + b[2].lc(), + y_lc.clone() - result.variable, + )?; + + Ok(result) } } @@ -938,7 +926,7 @@ impl CondSelectGadget for FpVar { Ok(is.mul_constant(*t).add(¬.mul_constant(*f)).into()) } (_, _) => { - let cs = cond.cs().unwrap(); + let cs = cond.cs(); let true_value = match true_value { Self::Constant(f) => AllocatedFp::new_constant(cs.clone(), f)?, Self::Var(v) => v.clone(), @@ -964,13 +952,13 @@ impl TwoBitLookupGadget for FpVar { fn two_bit_lookup(b: &[Boolean], c: &[Self::TableConstant]) -> Result { debug_assert_eq!(b.len(), 2); debug_assert_eq!(c.len(), 4); - if b.cs().is_some() { - AllocatedFp::two_bit_lookup(b, c).map(Self::Var) - } else { + if b.is_constant() { let lsb = usize::from(b[0].value()?); let msb = usize::from(b[1].value()?); let index = lsb + (msb << 1); Ok(Self::Constant(c[index])) + } else { + AllocatedFp::two_bit_lookup(b, c).map(Self::Var) } } } @@ -987,9 +975,9 @@ impl ThreeBitCondNegLookupGadget for FpVar { debug_assert_eq!(b.len(), 3); debug_assert_eq!(c.len(), 4); - if b.cs().or(b0b1.cs()).is_some() { - AllocatedFp::three_bit_cond_neg_lookup(b, b0b1, c).map(Self::Var) - } else { + if !b.cs().or(b0b1.cs()).is_none() { + // We only have constants + let lsb = usize::from(b[0].value()?); let msb = usize::from(b[1].value()?); let index = lsb + (msb << 1); @@ -1002,6 +990,8 @@ impl ThreeBitCondNegLookupGadget for FpVar { intermediate }; Ok(Self::Constant(y)) + } else { + AllocatedFp::three_bit_cond_neg_lookup(b, b0b1, c).map(Self::Var) } } } diff --git a/r1cs-std/src/fields/mod.rs b/r1cs-std/src/fields/mod.rs index 1d3d2f9..5320e1e 100644 --- a/r1cs-std/src/fields/mod.rs +++ b/r1cs-std/src/fields/mod.rs @@ -152,7 +152,7 @@ pub trait FieldVar: /// It is up to the caller to ensure that denominator is non-zero, /// since in that case the result is unconstrained. fn mul_by_inverse(&self, denominator: &Self) -> Result { - let result = Self::new_witness(self.cs().unwrap(), || { + let result = Self::new_witness(self.cs(), || { let denominator_inv_native = denominator.value()?.inverse().get()?; let result = self.value()? * &denominator_inv_native; Ok(result) diff --git a/r1cs-std/src/fields/quadratic_extension.rs b/r1cs-std/src/fields/quadratic_extension.rs index bf2c9ae..4b1cb30 100644 --- a/r1cs-std/src/fields/quadratic_extension.rs +++ b/r1cs-std/src/fields/quadratic_extension.rs @@ -9,7 +9,7 @@ use crate::fields::fp::FpVar; use crate::{ fields::{FieldOpsBounds, FieldVar}, prelude::*, - Assignment, ToConstraintFieldGadget, Vec, + ToConstraintFieldGadget, Vec, }; /// This struct is the `R1CS` equivalent of the quadratic extension field type @@ -122,7 +122,7 @@ where { type Value = QuadExtField

; - fn cs(&self) -> Option> { + fn cs(&self) -> ConstraintSystemRef { [&self.c0, &self.c1].cs() } @@ -279,7 +279,7 @@ where AllocationMode::Witness }; let inverse = Self::new_variable( - self.cs().get()?.clone(), + self.cs(), || { self.value() .map(|f| f.inverse().unwrap_or(QuadExtField::zero())) diff --git a/r1cs-std/src/groups/curves/short_weierstrass/mod.rs b/r1cs-std/src/groups/curves/short_weierstrass/mod.rs index 21091fc..b73470e 100644 --- a/r1cs-std/src/groups/curves/short_weierstrass/mod.rs +++ b/r1cs-std/src/groups/curves/short_weierstrass/mod.rs @@ -119,7 +119,7 @@ where { type Value = SWProjective

; - fn cs(&self) -> Option::BasePrimeField>> { + fn cs(&self) -> ConstraintSystemRef<::BasePrimeField> { self.x.cs().or(self.y.cs()).or(self.z.cs()) } @@ -152,7 +152,7 @@ where /// Convert this point into affine form. #[tracing::instrument(target = "r1cs")] pub fn to_affine(&self) -> Result, SynthesisError> { - let cs = self.cs().unwrap_or(ConstraintSystemRef::None); + let cs = self.cs(); let mode = if self.is_constant() { let point = self.value()?.into_affine(); let x = F::new_constant(ConstraintSystemRef::None, point.x)?; diff --git a/r1cs-std/src/groups/curves/twisted_edwards/mod.rs b/r1cs-std/src/groups/curves/twisted_edwards/mod.rs index 77e205f..f8dc5ce 100644 --- a/r1cs-std/src/groups/curves/twisted_edwards/mod.rs +++ b/r1cs-std/src/groups/curves/twisted_edwards/mod.rs @@ -49,7 +49,7 @@ mod montgomery_affine_impl { { type Value = (P::BaseField, P::BaseField); - fn cs(&self) -> Option::BasePrimeField>> { + fn cs(&self) -> ConstraintSystemRef<::BasePrimeField> { self.x.cs().or(self.y.cs()) } @@ -112,7 +112,7 @@ mod montgomery_affine_impl { /// Converts `self` into a Twisted Edwards curve point variable. #[tracing::instrument(target = "r1cs")] pub fn into_edwards(&self) -> Result, SynthesisError> { - let cs = self.cs().unwrap_or(ConstraintSystemRef::None); + let cs = self.cs(); // Compute u = x / y let u = F::new_witness(r1cs_core::ns!(cs, "u"), || { let y_inv = self @@ -153,12 +153,11 @@ mod montgomery_affine_impl { #[tracing::instrument(target = "r1cs")] fn add(self, other: &'a Self) -> Self::Output { let cs = [&self, other].cs(); - let mode = if cs.is_none() || matches!(cs, Some(ConstraintSystemRef::None)) { + let mode = if cs.is_none() { AllocationMode::Constant } else { AllocationMode::Witness }; - let cs = cs.unwrap_or(ConstraintSystemRef::None); let coeff_b = P::MontgomeryModelParameters::COEFF_B; let coeff_a = P::MontgomeryModelParameters::COEFF_A; @@ -378,7 +377,7 @@ where { type Value = TEProjective

; - fn cs(&self) -> Option::BasePrimeField>> { + fn cs(&self) -> ConstraintSystemRef<::BasePrimeField> { self.x.cs().or(self.y.cs()) } @@ -465,7 +464,7 @@ where let value = self.value()?; *self = Self::constant(value.double()); } else { - let cs = self.cs().unwrap(); + let cs = self.cs(); let a = P::COEFF_A; // xy @@ -714,7 +713,7 @@ impl_bounded_ops!( assert!(this.is_constant() && other.is_constant()); AffineVar::constant(this.value().unwrap() + &other.value().unwrap()) } else { - let cs = [this, other].cs().unwrap(); + let cs = [this, other].cs(); let a = P::COEFF_A; let d = P::COEFF_D; diff --git a/r1cs-std/src/lib.rs b/r1cs-std/src/lib.rs index 27bae96..e017234 100644 --- a/r1cs-std/src/lib.rs +++ b/r1cs-std/src/lib.rs @@ -119,12 +119,13 @@ pub trait R1CSVar { type Value: core::fmt::Debug + Eq + Clone; /// Returns the underlying `ConstraintSystemRef`. - fn cs(&self) -> Option>; + /// + /// If `self` is a constant value, then this *must* return `r1cs_core::ConstraintSystemRef::None`. + fn cs(&self) -> r1cs_core::ConstraintSystemRef; /// Returns `true` if `self` is a circuit-generation-time constant. fn is_constant(&self) -> bool { - self.cs() - .map_or(true, |cs| cs == r1cs_core::ConstraintSystemRef::None) + self.cs().is_none() } /// Returns the value that is assigned to `self` in the underlying @@ -135,8 +136,8 @@ pub trait R1CSVar { impl> R1CSVar for [T] { type Value = Vec; - fn cs(&self) -> Option> { - let mut result = None; + fn cs(&self) -> r1cs_core::ConstraintSystemRef { + let mut result = r1cs_core::ConstraintSystemRef::None; for var in self { result = var.cs().or(result); } @@ -155,7 +156,7 @@ impl> R1CSVar for [T] { impl<'a, F: Field, T: 'a + R1CSVar> R1CSVar for &'a T { type Value = T::Value; - fn cs(&self) -> Option> { + fn cs(&self) -> r1cs_core::ConstraintSystemRef { (*self).cs() }