Boolean conditional select

This commit is contained in:
Pratyush Mishra
2019-12-12 01:40:41 -08:00
parent ab65b01478
commit 2a8974e442
2 changed files with 231 additions and 68 deletions

View File

@@ -1,7 +1,7 @@
use algebra::{BitIterator, Field, FpParameters, PrimeField};
use crate::{prelude::*, Assignment};
use r1cs_core::{ConstraintSystem, LinearCombination, SynthesisError, Variable};
use r1cs_core::{ConstraintSystem, LinearCombination, SynthesisError, Variable, ConstraintVar};
use std::borrow::Borrow;
/// Represents a variable in the constraint system which is guaranteed
@@ -114,12 +114,39 @@ impl AllocatedBit {
/// Performs an OR operation over the two operands, returning
/// an `AllocatedBit`.
pub fn or<ConstraintF, CS>(cs: CS, a: &Self, b: &Self) -> Result<Self, SynthesisError>
pub fn or<ConstraintF, CS>(mut cs: CS, a: &Self, b: &Self) -> Result<Self, SynthesisError>
where
ConstraintF: Field,
CS: ConstraintSystem<ConstraintF>,
{
Self::conditionally_select(cs, &Boolean::from(*a), a, b)
let mut result_value = None;
let result_var = cs.alloc(
|| "or result",
|| {
if a.value.get()? | b.value.get()? {
result_value = Some(true);
Ok(ConstraintF::one())
} else {
result_value = Some(false);
Ok(ConstraintF::zero())
}
},
)?;
// Constrain (1 - a) * (1 - b) = (c), ensuring c is 1 iff
// a and b are both false, and otherwise c is 0.
cs.enforce(
|| "nor constraint",
|lc| lc + CS::one() - a.variable,
|lc| lc + CS::one() - b.variable,
|lc| lc + CS::one() - result_var,
);
Ok(AllocatedBit {
variable: result_var,
value: result_value,
})
}
/// Calculates `a AND (NOT b)`.
@@ -283,40 +310,17 @@ impl<ConstraintF: Field> AllocGadget<bool, ConstraintF> for AllocatedBit {
impl<ConstraintF: Field> CondSelectGadget<ConstraintF> for AllocatedBit {
fn conditionally_select<CS: ConstraintSystem<ConstraintF>>(
mut cs: CS,
cs: CS,
cond: &Boolean,
first: &Self,
second: &Self,
) -> Result<Self, SynthesisError> {
let result = Self::alloc(cs.ns(|| ""), || {
cond.get_value()
.and_then(|cond| {
{
if cond {
first
} else {
second
}
}
.get_value()
})
.get()
})?;
// a = self; b = other; c = cond;
//
// r = c * a + (1 - c) * b
// r = b + c * (a - b)
// c * (a - b) = r - b
let one = CS::one();
cs.enforce(
|| "conditionally_select",
|_| cond.lc(one, ConstraintF::one()),
|lc| lc + first.variable - second.variable,
|lc| lc + result.variable - second.variable,
);
Ok(result)
cond_select_helper(
cs,
cond,
(first.value, first.variable),
(second.value, second.variable),
)
}
fn cost() -> usize {
@@ -324,6 +328,40 @@ impl<ConstraintF: Field> CondSelectGadget<ConstraintF> for AllocatedBit {
}
}
fn cond_select_helper<F: Field, CS: ConstraintSystem<F>>(
mut cs: CS,
cond: &Boolean,
first: (Option<bool>, impl Into<ConstraintVar<F>>),
second: (Option<bool>, impl Into<ConstraintVar<F>>),
) -> Result<AllocatedBit, SynthesisError> {
let mut result_val = None;
let result_var = cs.alloc(
|| "cond_select_result",
|| {
result_val = cond.get_value().and_then(|c| if c { first.0 } else { second.0 });
result_val.get().map(|v| F::from(v as u8))
})?;
let first_var = first.1.into();
let second_var = second.1.into();
// a = self; b = other; c = cond;
//
// r = c * a + (1 - c) * b
// r = b + c * (a - b)
// c * (a - b) = r - b
let one = CS::one();
cs.enforce(
|| "conditionally_select",
|_| cond.lc(one, F::one()),
|lc| (&first_var - &second_var) + lc,
|lc| ConstraintVar::from(result_var) - &second_var + lc,
);
Ok(AllocatedBit { value: result_val, variable: result_var })
}
/// This is a boolean value which may be either a constant or
/// an interpretation of an `AllocatedBit`.
#[derive(Copy, Clone, Debug)]
@@ -747,6 +785,54 @@ impl<ConstraintF: Field> ToBytesGadget<ConstraintF> for Boolean {
self.to_bytes(cs)
}
}
impl<ConstraintF: Field> CondSelectGadget<ConstraintF> for Boolean {
fn conditionally_select<CS>(
mut cs: CS,
cond: &Self,
first: &Self,
second: &Self,
) -> Result<Self, SynthesisError>
where
CS: ConstraintSystem<ConstraintF>,
{
match cond {
Boolean::Constant(true) => Ok(first.clone()),
Boolean::Constant(false) => Ok(second.clone()),
cond @ Boolean::Not(_) => Self::conditionally_select(cs, &cond.not(), second, first),
cond @ Boolean::Is(_) => {
match (first, second) {
(x, &Boolean::Constant(false)) => {
Boolean::and(cs.ns(|| "and"), cond, x).into()
},
(&Boolean::Constant(false), x) => {
Boolean::and(cs.ns(|| "and"), &cond.not(), x)
},
(&Boolean::Constant(true), x) => {
Boolean::or(cs.ns(|| "or"), cond, x).into()
},
(x, &Boolean::Constant(true)) => {
Boolean::or(cs.ns(|| "or"), &cond.not(), x)
},
(a @ Boolean::Is(_), b @ Boolean::Is(_))
| (a @ Boolean::Not(_), b @ Boolean::Not(_))
| (a @ Boolean::Is(_), b @ Boolean::Not(_))
| (a @ Boolean::Not(_), b @ Boolean::Is(_)) => {
let a_lc = a.lc(CS::one(), ConstraintF::one());
let b_lc = b.lc(CS::one(), ConstraintF::one());
Ok(cond_select_helper(cs, cond, (a.get_value(), a_lc), (b.get_value(), b_lc))?.into())
},
}
}
}
}
fn cost() -> usize {
1
}
}
#[cfg(test)]
mod test {
@@ -1288,6 +1374,71 @@ mod test {
}
}
#[test]
fn test_boolean_cond_select() {
let variants = [
OperandType::True,
OperandType::False,
OperandType::AllocatedTrue,
OperandType::AllocatedFalse,
OperandType::NegatedAllocatedTrue,
OperandType::NegatedAllocatedFalse,
];
for condition in variants.iter().cloned() {
for first_operand in variants.iter().cloned() {
for second_operand in variants.iter().cloned() {
let mut cs = TestConstraintSystem::<Fr>::new();
let cond;
let a;
let b;
{
let mut dyn_construct = |operand, name| {
let cs = cs.ns(|| name);
match operand {
OperandType::True => Boolean::constant(true),
OperandType::False => Boolean::constant(false),
OperandType::AllocatedTrue => {
Boolean::from(AllocatedBit::alloc(cs, || Ok(true)).unwrap())
},
OperandType::AllocatedFalse => {
Boolean::from(AllocatedBit::alloc(cs, || Ok(false)).unwrap())
},
OperandType::NegatedAllocatedTrue => {
Boolean::from(AllocatedBit::alloc(cs, || Ok(true)).unwrap()).not()
},
OperandType::NegatedAllocatedFalse => {
Boolean::from(AllocatedBit::alloc(cs, || Ok(false)).unwrap()).not()
},
}
};
cond = dyn_construct(condition, "cond");
a = dyn_construct(first_operand, "a");
b = dyn_construct(second_operand, "b");
}
let before = cs.num_constraints();
let c = Boolean::conditionally_select(&mut cs, &cond, &a, &b).unwrap();
let after = cs.num_constraints();
assert!(
cs.is_satisfied(),
"failed with operands: cond: {:?}, a: {:?}, b: {:?}",
condition,
first_operand,
second_operand,
);
assert_eq!(c.get_value(), if cond.get_value().unwrap() { a.get_value() } else { b.get_value() });
assert!(<Boolean as CondSelectGadget<Fr>>::cost() >= after - before);
}
}
}
}
#[test]
fn test_boolean_or() {
let variants = [