diff --git a/r1cs-std/src/fields/fp/mod.rs b/r1cs-std/src/fields/fp/mod.rs index 6648864..47d592b 100644 --- a/r1cs-std/src/fields/fp/mod.rs +++ b/r1cs-std/src/fields/fp/mod.rs @@ -22,6 +22,22 @@ impl FpGadget { pub fn from>(mut cs: CS, value: &F) -> Self { Self::alloc(cs.ns(|| "from"), || Ok(*value)).unwrap() } + + fn is_constant(&self) -> bool { + match &self.variable { + // If you don't do alloc_constant, you are guaranteed to get a variable, + // hence we assume that all variables are not the constant variable. + // Technically this omits recognizing some constants. + // E.g. given variables w,x,y,z with constraints: + // w = x + 1 + // y = -x + 1 + // and then created the variable z = w + y, + // this would not recognize that z is in fact a constant. + // Since this is an edge case, this is left as a TODO. + Var(_v) => false, + LC(l) => l.is_constant(), + } + } } impl FieldGadget for FpGadget { @@ -145,6 +161,14 @@ impl FieldGadget for FpGadget { mut cs: CS, other: &Self, ) -> Result { + // Apply constant folding if it applies + // unwrap is used, because these values are guaranteed to exist. + if other.is_constant() { + return self.mul_by_constant(cs, &other.get_value().unwrap()); + } else if self.is_constant() { + return other.mul_by_constant(cs, &self.get_value().unwrap()); + } + let product = Self::alloc(cs.ns(|| "mul"), || { Ok(self.value.get()? * &other.value.get()?) })?; diff --git a/r1cs-std/src/fields/mod.rs b/r1cs-std/src/fields/mod.rs index 51e254f..9807958 100644 --- a/r1cs-std/src/fields/mod.rs +++ b/r1cs-std/src/fields/mod.rs @@ -290,6 +290,7 @@ pub(crate) mod tests { let b_native = FE::rand(&mut rng); let a = F::alloc(&mut cs.ns(|| "generate_a"), || Ok(a_native)).unwrap(); let b = F::alloc(&mut cs.ns(|| "generate_b"), || Ok(b_native)).unwrap(); + let b_const = F::alloc_constant(&mut cs.ns(|| "generate_b_as_constant"), b_native).unwrap(); let zero = F::zero(cs.ns(|| "zero")).unwrap(); let zero_native = zero.get_value().unwrap(); @@ -398,6 +399,12 @@ pub(crate) mod tests { assert_eq!(ab, ba); assert_eq!(ab.get_value().unwrap(), a_native * &b_native); + let ab_const = a.mul(cs.ns(|| "a_times_b_const"), &b_const).unwrap(); + let b_const_a = b_const.mul(cs.ns(|| "b_const_times_a"), &a).unwrap(); + assert_eq!(ab_const, b_const_a); + assert_eq!(ab_const, ab); + assert_eq!(ab_const.get_value().unwrap(), a_native * &b_native); + // (a * b) * a = a * (b * a) let ab_a = ab.mul(cs.ns(|| "ab_times_a"), &a).unwrap(); let a_ba = a.mul(cs.ns(|| "a_times_ba"), &ba).unwrap();