Browse Source

Add constant folding to FpGadget<F> (#226)

master
Dev Ojha 4 years ago
committed by GitHub
parent
commit
75439b9b4d
2 changed files with 31 additions and 0 deletions
  1. +24
    -0
      r1cs-std/src/fields/fp/mod.rs
  2. +7
    -0
      r1cs-std/src/fields/mod.rs

+ 24
- 0
r1cs-std/src/fields/fp/mod.rs

@ -22,6 +22,22 @@ impl FpGadget {
pub fn from<CS: ConstraintSystem<F>>(mut cs: CS, value: &F) -> Self {
Self::alloc(cs.ns(|| "from"), || Ok(*value)).unwrap()
}
fn is_constant(&self) -> bool {
match &self.variable {
// If you don't do alloc_constant, you are guaranteed to get a variable,
// hence we assume that all variables are not the constant variable.
// Technically this omits recognizing some constants.
// E.g. given variables w,x,y,z with constraints:
// w = x + 1
// y = -x + 1
// and then created the variable z = w + y,
// this would not recognize that z is in fact a constant.
// Since this is an edge case, this is left as a TODO.
Var(_v) => false,
LC(l) => l.is_constant(),
}
}
}
impl<F: PrimeField> FieldGadget<F, F> for FpGadget<F> {
@ -145,6 +161,14 @@ impl FieldGadget for FpGadget {
mut cs: CS,
other: &Self,
) -> Result<Self, SynthesisError> {
// Apply constant folding if it applies
// unwrap is used, because these values are guaranteed to exist.
if other.is_constant() {
return self.mul_by_constant(cs, &other.get_value().unwrap());
} else if self.is_constant() {
return other.mul_by_constant(cs, &self.get_value().unwrap());
}
let product = Self::alloc(cs.ns(|| "mul"), || {
Ok(self.value.get()? * &other.value.get()?)
})?;

+ 7
- 0
r1cs-std/src/fields/mod.rs

@ -290,6 +290,7 @@ pub(crate) mod tests {
let b_native = FE::rand(&mut rng);
let a = F::alloc(&mut cs.ns(|| "generate_a"), || Ok(a_native)).unwrap();
let b = F::alloc(&mut cs.ns(|| "generate_b"), || Ok(b_native)).unwrap();
let b_const = F::alloc_constant(&mut cs.ns(|| "generate_b_as_constant"), b_native).unwrap();
let zero = F::zero(cs.ns(|| "zero")).unwrap();
let zero_native = zero.get_value().unwrap();
@ -398,6 +399,12 @@ pub(crate) mod tests {
assert_eq!(ab, ba);
assert_eq!(ab.get_value().unwrap(), a_native * &b_native);
let ab_const = a.mul(cs.ns(|| "a_times_b_const"), &b_const).unwrap();
let b_const_a = b_const.mul(cs.ns(|| "b_const_times_a"), &a).unwrap();
assert_eq!(ab_const, b_const_a);
assert_eq!(ab_const, ab);
assert_eq!(ab_const.get_value().unwrap(), a_native * &b_native);
// (a * b) * a = a * (b * a)
let ab_a = ab.mul(cs.ns(|| "ab_times_a"), &a).unwrap();
let a_ba = a.mul(cs.ns(|| "a_times_ba"), &ba).unwrap();

Loading…
Cancel
Save