From 2a8974e442a387636ac3fded8720ec6e4f6df169 Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Thu, 12 Dec 2019 01:40:41 -0800 Subject: [PATCH] Boolean conditional select --- r1cs-std/src/bits/boolean.rs | 217 +++++++++++++++--- .../groups/curves/short_weierstrass/mod.rs | 82 ++++--- 2 files changed, 231 insertions(+), 68 deletions(-) diff --git a/r1cs-std/src/bits/boolean.rs b/r1cs-std/src/bits/boolean.rs index f70ba1c..0bcbacf 100644 --- a/r1cs-std/src/bits/boolean.rs +++ b/r1cs-std/src/bits/boolean.rs @@ -1,7 +1,7 @@ use algebra::{BitIterator, Field, FpParameters, PrimeField}; use crate::{prelude::*, Assignment}; -use r1cs_core::{ConstraintSystem, LinearCombination, SynthesisError, Variable}; +use r1cs_core::{ConstraintSystem, LinearCombination, SynthesisError, Variable, ConstraintVar}; use std::borrow::Borrow; /// Represents a variable in the constraint system which is guaranteed @@ -114,12 +114,39 @@ impl AllocatedBit { /// Performs an OR operation over the two operands, returning /// an `AllocatedBit`. - pub fn or(cs: CS, a: &Self, b: &Self) -> Result + pub fn or(mut cs: CS, a: &Self, b: &Self) -> Result where ConstraintF: Field, CS: ConstraintSystem, { - Self::conditionally_select(cs, &Boolean::from(*a), a, b) + let mut result_value = None; + + let result_var = cs.alloc( + || "or result", + || { + if a.value.get()? | b.value.get()? { + result_value = Some(true); + Ok(ConstraintF::one()) + } else { + result_value = Some(false); + Ok(ConstraintF::zero()) + } + }, + )?; + + // Constrain (1 - a) * (1 - b) = (c), ensuring c is 1 iff + // a and b are both false, and otherwise c is 0. + cs.enforce( + || "nor constraint", + |lc| lc + CS::one() - a.variable, + |lc| lc + CS::one() - b.variable, + |lc| lc + CS::one() - result_var, + ); + + Ok(AllocatedBit { + variable: result_var, + value: result_value, + }) } /// Calculates `a AND (NOT b)`. @@ -283,40 +310,17 @@ impl AllocGadget for AllocatedBit { impl CondSelectGadget for AllocatedBit { fn conditionally_select>( - mut cs: CS, + cs: CS, cond: &Boolean, first: &Self, second: &Self, ) -> Result { - let result = Self::alloc(cs.ns(|| ""), || { - cond.get_value() - .and_then(|cond| { - { - if cond { - first - } else { - second - } - } - .get_value() - }) - .get() - })?; - - // a = self; b = other; c = cond; - // - // r = c * a + (1 - c) * b - // r = b + c * (a - b) - // c * (a - b) = r - b - let one = CS::one(); - cs.enforce( - || "conditionally_select", - |_| cond.lc(one, ConstraintF::one()), - |lc| lc + first.variable - second.variable, - |lc| lc + result.variable - second.variable, - ); - - Ok(result) + cond_select_helper( + cs, + cond, + (first.value, first.variable), + (second.value, second.variable), + ) } fn cost() -> usize { @@ -324,6 +328,40 @@ impl CondSelectGadget for AllocatedBit { } } +fn cond_select_helper>( + mut cs: CS, + cond: &Boolean, + first: (Option, impl Into>), + second: (Option, impl Into>), +) -> Result { + let mut result_val = None; + let result_var = cs.alloc( + || "cond_select_result", + || { + result_val = cond.get_value().and_then(|c| if c { first.0 } else { second.0 }); + result_val.get().map(|v| F::from(v as u8)) + })?; + + let first_var = first.1.into(); + let second_var = second.1.into(); + + // a = self; b = other; c = cond; + // + // r = c * a + (1 - c) * b + // r = b + c * (a - b) + // c * (a - b) = r - b + let one = CS::one(); + cs.enforce( + || "conditionally_select", + |_| cond.lc(one, F::one()), + |lc| (&first_var - &second_var) + lc, + |lc| ConstraintVar::from(result_var) - &second_var + lc, + ); + + Ok(AllocatedBit { value: result_val, variable: result_var }) + +} + /// This is a boolean value which may be either a constant or /// an interpretation of an `AllocatedBit`. #[derive(Copy, Clone, Debug)] @@ -747,6 +785,54 @@ impl ToBytesGadget for Boolean { self.to_bytes(cs) } } +impl CondSelectGadget for Boolean { + fn conditionally_select( + mut cs: CS, + cond: &Self, + first: &Self, + second: &Self, + ) -> Result + where + CS: ConstraintSystem, + { + match cond { + Boolean::Constant(true) => Ok(first.clone()), + Boolean::Constant(false) => Ok(second.clone()), + cond @ Boolean::Not(_) => Self::conditionally_select(cs, &cond.not(), second, first), + cond @ Boolean::Is(_) => { + match (first, second) { + (x, &Boolean::Constant(false)) => { + Boolean::and(cs.ns(|| "and"), cond, x).into() + }, + (&Boolean::Constant(false), x) => { + Boolean::and(cs.ns(|| "and"), &cond.not(), x) + }, + (&Boolean::Constant(true), x) => { + Boolean::or(cs.ns(|| "or"), cond, x).into() + }, + (x, &Boolean::Constant(true)) => { + Boolean::or(cs.ns(|| "or"), &cond.not(), x) + }, + (a @ Boolean::Is(_), b @ Boolean::Is(_)) + | (a @ Boolean::Not(_), b @ Boolean::Not(_)) + | (a @ Boolean::Is(_), b @ Boolean::Not(_)) + | (a @ Boolean::Not(_), b @ Boolean::Is(_)) => { + let a_lc = a.lc(CS::one(), ConstraintF::one()); + let b_lc = b.lc(CS::one(), ConstraintF::one()); + Ok(cond_select_helper(cs, cond, (a.get_value(), a_lc), (b.get_value(), b_lc))?.into()) + }, + } + + } + } + } + + fn cost() -> usize { + 1 + } +} + + #[cfg(test)] mod test { @@ -1288,6 +1374,71 @@ mod test { } } + #[test] + fn test_boolean_cond_select() { + let variants = [ + OperandType::True, + OperandType::False, + OperandType::AllocatedTrue, + OperandType::AllocatedFalse, + OperandType::NegatedAllocatedTrue, + OperandType::NegatedAllocatedFalse, + ]; + + for condition in variants.iter().cloned() { + for first_operand in variants.iter().cloned() { + for second_operand in variants.iter().cloned() { + let mut cs = TestConstraintSystem::::new(); + + let cond; + let a; + let b; + + { + let mut dyn_construct = |operand, name| { + let cs = cs.ns(|| name); + + match operand { + OperandType::True => Boolean::constant(true), + OperandType::False => Boolean::constant(false), + OperandType::AllocatedTrue => { + Boolean::from(AllocatedBit::alloc(cs, || Ok(true)).unwrap()) + }, + OperandType::AllocatedFalse => { + Boolean::from(AllocatedBit::alloc(cs, || Ok(false)).unwrap()) + }, + OperandType::NegatedAllocatedTrue => { + Boolean::from(AllocatedBit::alloc(cs, || Ok(true)).unwrap()).not() + }, + OperandType::NegatedAllocatedFalse => { + Boolean::from(AllocatedBit::alloc(cs, || Ok(false)).unwrap()).not() + }, + } + }; + + cond = dyn_construct(condition, "cond"); + a = dyn_construct(first_operand, "a"); + b = dyn_construct(second_operand, "b"); + } + + let before = cs.num_constraints(); + let c = Boolean::conditionally_select(&mut cs, &cond, &a, &b).unwrap(); + let after = cs.num_constraints(); + + assert!( + cs.is_satisfied(), + "failed with operands: cond: {:?}, a: {:?}, b: {:?}", + condition, + first_operand, + second_operand, + ); + assert_eq!(c.get_value(), if cond.get_value().unwrap() { a.get_value() } else { b.get_value() }); + assert!(>::cost() >= after - before); + } + } + } + } + #[test] fn test_boolean_or() { let variants = [ diff --git a/r1cs-std/src/groups/curves/short_weierstrass/mod.rs b/r1cs-std/src/groups/curves/short_weierstrass/mod.rs index f2a0ecf..55ad019 100644 --- a/r1cs-std/src/groups/curves/short_weierstrass/mod.rs +++ b/r1cs-std/src/groups/curves/short_weierstrass/mod.rs @@ -22,6 +22,7 @@ pub struct AffineGadget< > { pub x: F, pub y: F, + pub infinity: Boolean, _params: PhantomData

, _engine: PhantomData, } @@ -29,10 +30,11 @@ pub struct AffineGadget< impl> AffineGadget { - pub fn new(x: F, y: F) -> Self { + pub fn new(x: F, y: F, infinity: Boolean) -> Self { Self { x, y, + infinity, _params: PhantomData, _engine: PhantomData, } @@ -45,21 +47,23 @@ impl Result, SynthesisError>, { - let (x, y) = match value_gen() { - Ok(fe) => { - let fe = fe.into_affine(); - (Ok(fe.x), Ok(fe.y)) + let (x, y, infinity) = match value_gen() { + Ok(ge) => { + let ge = ge.into_affine(); + (Ok(ge.x), Ok(ge.y), Ok(ge.infinity)) }, _ => ( Err(SynthesisError::AssignmentMissing), Err(SynthesisError::AssignmentMissing), + Err(SynthesisError::AssignmentMissing), ), }; let x = F::alloc(&mut cs.ns(|| "x"), || x)?; let y = F::alloc(&mut cs.ns(|| "y"), || y)?; + let infinity = Boolean::alloc(&mut cs.ns(|| "infinity"), || infinity)?; - Ok(Self::new(x, y)) + Ok(Self::new(x, y, infinity)) } } @@ -94,12 +98,11 @@ where #[inline] fn get_value(&self) -> Option { - match (self.x.get_value(), self.y.get_value()) { - (Some(x), Some(y)) => { - let is_zero = x.is_zero() && y.is_one(); - Some(SWAffine::new(x, y, is_zero).into_projective()) + match (self.x.get_value(), self.y.get_value(), self.infinity.get_value()) { + (Some(x), Some(y), Some(infinity)) => { + Some(SWAffine::new(x, y, infinity).into_projective()) }, - (None, None) => None, + (None, None, None) => None, _ => unreachable!(), } } @@ -114,6 +117,7 @@ where Ok(Self::new( F::zero(cs.ns(|| "zero"))?, F::one(cs.ns(|| "one"))?, + Boolean::Constant(true), )) } @@ -140,6 +144,7 @@ where // // So we need to check that A.x - B.x != 0, which can be done by // enforcing I * (B.x - A.x) = 1 + // This is done below when we calculate inv (by F::inverse) let x2_minus_x1 = other.x.sub(cs.ns(|| "x2 - x1"), &self.x)?; let y2_minus_y1 = other.y.sub(cs.ns(|| "y2 - y1"), &self.y)?; @@ -180,7 +185,7 @@ where lambda.mul_equals(cs.ns(|| ""), &x1_minus_x3, &y3_plus_y1)?; - Ok(Self::new(x_3, y_3)) + Ok(Self::new(x_3, y_3, Boolean::Constant(false))) } /// Incomplete addition: neither `self` nor `other` can be the neutral @@ -257,7 +262,7 @@ where lambda.mul_equals(cs.ns(|| ""), &x1_minus_x3, &y3_plus_y1)?; - Ok(Self::new(x_3, y_3)) + Ok(Self::new(x_3, y_3, Boolean::Constant(false))) } #[inline] @@ -296,7 +301,7 @@ where .mul(cs.ns(|| "times lambda"), &lambda)? .sub(cs.ns(|| "plus self.y"), &self.y)?; - *self = Self::new(x, y); + *self = Self::new(x, y, Boolean::Constant(false)); Ok(()) } @@ -307,6 +312,7 @@ where Ok(Self::new( self.x.clone(), self.y.negate(cs.ns(|| "negate y"))?, + self.infinity, )) } @@ -334,12 +340,14 @@ where ) -> Result { let x = F::conditionally_select(&mut cs.ns(|| "x"), cond, &first.x, &second.x)?; let y = F::conditionally_select(&mut cs.ns(|| "y"), cond, &first.y, &second.y)?; + let infinity = Boolean::conditionally_select(&mut cs.ns(|| "infinity"), cond, &first.infinity, &second.infinity)?; - Ok(Self::new(x, y)) + Ok(Self::new(x, y, infinity)) } fn cost() -> usize { - 2 * >::cost() + 2 * >::cost() + + >::cost() } } @@ -374,6 +382,11 @@ where &other.y, condition, )?; + self.infinity.conditional_enforce_equal( + &mut cs.ns(|| "Infinity Conditional Equality"), + &other.infinity, + condition, + )?; Ok(()) } @@ -422,14 +435,15 @@ where FN: FnOnce() -> Result, T: Borrow>, { - let (x, y) = match value_gen() { + let (x, y, infinity) = match value_gen() { Ok(ge) => { let ge = ge.borrow().into_affine(); - (Ok(ge.x), Ok(ge.y)) + (Ok(ge.x), Ok(ge.y), Ok(ge.infinity)) }, _ => ( Err(SynthesisError::AssignmentMissing), Err(SynthesisError::AssignmentMissing), + Err(SynthesisError::AssignmentMissing), ), }; @@ -439,6 +453,7 @@ where let x = F::alloc(&mut cs.ns(|| "x"), || x)?; let y = F::alloc(&mut cs.ns(|| "y"), || y)?; + let infinity = Boolean::alloc(&mut cs.ns(|| "infinity"), || infinity)?; // Check that y^2 = x^3 + ax +b // We do this by checking that y^2 - b = x * (x^2 +a) @@ -450,7 +465,7 @@ where x2_plus_a.mul_equals(cs.ns(|| "on curve check"), &x, &y2_minus_b)?; - Ok(Self::new(x, y)) + Ok(Self::new(x, y, infinity)) } #[inline] @@ -542,34 +557,25 @@ where FN: FnOnce() -> Result, T: Borrow>, { - let (x, y) = match value_gen() { + // When allocating the input we assume that the verifier has performed + // any on curve checks already. + let (x, y, infinity) = match value_gen() { Ok(ge) => { let ge = ge.borrow().into_affine(); - (Ok(ge.x), Ok(ge.y)) + (Ok(ge.x), Ok(ge.y), Ok(ge.infinity)) }, _ => ( Err(SynthesisError::AssignmentMissing), Err(SynthesisError::AssignmentMissing), + Err(SynthesisError::AssignmentMissing), ), }; - let b = P::COEFF_B; - let a = P::COEFF_A; - let x = F::alloc_input(&mut cs.ns(|| "x"), || x)?; let y = F::alloc_input(&mut cs.ns(|| "y"), || y)?; + let infinity = Boolean::alloc_input(&mut cs.ns(|| "infinity"), || infinity)?; - // Check that y^2 = x^3 + ax +b - // We do this by checking that y^2 - b = x * (x^2 +a) - let x2 = x.square(&mut cs.ns(|| "x^2"))?; - let y2 = y.square(&mut cs.ns(|| "y^2"))?; - - let x2_plus_a = x2.add_constant(cs.ns(|| "x^2 + a"), &a)?; - let y2_minus_b = y2.add_constant(cs.ns(|| "y^2 - b"), &b.neg())?; - - x2_plus_a.mul_equals(cs.ns(|| "on curve check"), &x, &y2_minus_b)?; - - Ok(Self::new(x, y)) + Ok(Self::new(x, y, infinity)) } } @@ -586,6 +592,7 @@ where let mut x_bits = self.x.to_bits(&mut cs.ns(|| "X Coordinate To Bits"))?; let y_bits = self.y.to_bits(&mut cs.ns(|| "Y Coordinate To Bits"))?; x_bits.extend_from_slice(&y_bits); + x_bits.push(self.infinity); Ok(x_bits) } @@ -600,6 +607,7 @@ where .y .to_bits_strict(&mut cs.ns(|| "Y Coordinate To Bits"))?; x_bits.extend_from_slice(&y_bits); + x_bits.push(self.infinity); Ok(x_bits) } @@ -617,7 +625,9 @@ where ) -> Result, SynthesisError> { let mut x_bytes = self.x.to_bytes(&mut cs.ns(|| "X Coordinate To Bytes"))?; let y_bytes = self.y.to_bytes(&mut cs.ns(|| "Y Coordinate To Bytes"))?; + let inf_bytes = self.infinity.to_bytes(&mut cs.ns(|| "Infinity to Bytes"))?; x_bytes.extend_from_slice(&y_bytes); + x_bytes.extend_from_slice(&inf_bytes); Ok(x_bytes) } @@ -631,7 +641,9 @@ where let y_bytes = self .y .to_bytes_strict(&mut cs.ns(|| "Y Coordinate To Bytes"))?; + let inf_bytes = self.infinity.to_bytes(&mut cs.ns(|| "Infinity to Bytes"))?; x_bytes.extend_from_slice(&y_bytes); + x_bytes.extend_from_slice(&inf_bytes); Ok(x_bytes) }