mirror of
https://github.com/arnaucube/ark-r1cs-std.git
synced 2026-01-23 12:13:48 +01:00
Boolean conditional select
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
use algebra::{BitIterator, Field, FpParameters, PrimeField};
|
||||
|
||||
use crate::{prelude::*, Assignment};
|
||||
use r1cs_core::{ConstraintSystem, LinearCombination, SynthesisError, Variable};
|
||||
use r1cs_core::{ConstraintSystem, LinearCombination, SynthesisError, Variable, ConstraintVar};
|
||||
use std::borrow::Borrow;
|
||||
|
||||
/// Represents a variable in the constraint system which is guaranteed
|
||||
@@ -114,12 +114,39 @@ impl AllocatedBit {
|
||||
|
||||
/// Performs an OR operation over the two operands, returning
|
||||
/// an `AllocatedBit`.
|
||||
pub fn or<ConstraintF, CS>(cs: CS, a: &Self, b: &Self) -> Result<Self, SynthesisError>
|
||||
pub fn or<ConstraintF, CS>(mut cs: CS, a: &Self, b: &Self) -> Result<Self, SynthesisError>
|
||||
where
|
||||
ConstraintF: Field,
|
||||
CS: ConstraintSystem<ConstraintF>,
|
||||
{
|
||||
Self::conditionally_select(cs, &Boolean::from(*a), a, b)
|
||||
let mut result_value = None;
|
||||
|
||||
let result_var = cs.alloc(
|
||||
|| "or result",
|
||||
|| {
|
||||
if a.value.get()? | b.value.get()? {
|
||||
result_value = Some(true);
|
||||
Ok(ConstraintF::one())
|
||||
} else {
|
||||
result_value = Some(false);
|
||||
Ok(ConstraintF::zero())
|
||||
}
|
||||
},
|
||||
)?;
|
||||
|
||||
// Constrain (1 - a) * (1 - b) = (c), ensuring c is 1 iff
|
||||
// a and b are both false, and otherwise c is 0.
|
||||
cs.enforce(
|
||||
|| "nor constraint",
|
||||
|lc| lc + CS::one() - a.variable,
|
||||
|lc| lc + CS::one() - b.variable,
|
||||
|lc| lc + CS::one() - result_var,
|
||||
);
|
||||
|
||||
Ok(AllocatedBit {
|
||||
variable: result_var,
|
||||
value: result_value,
|
||||
})
|
||||
}
|
||||
|
||||
/// Calculates `a AND (NOT b)`.
|
||||
@@ -283,40 +310,17 @@ impl<ConstraintF: Field> AllocGadget<bool, ConstraintF> for AllocatedBit {
|
||||
|
||||
impl<ConstraintF: Field> CondSelectGadget<ConstraintF> for AllocatedBit {
|
||||
fn conditionally_select<CS: ConstraintSystem<ConstraintF>>(
|
||||
mut cs: CS,
|
||||
cs: CS,
|
||||
cond: &Boolean,
|
||||
first: &Self,
|
||||
second: &Self,
|
||||
) -> Result<Self, SynthesisError> {
|
||||
let result = Self::alloc(cs.ns(|| ""), || {
|
||||
cond.get_value()
|
||||
.and_then(|cond| {
|
||||
{
|
||||
if cond {
|
||||
first
|
||||
} else {
|
||||
second
|
||||
}
|
||||
}
|
||||
.get_value()
|
||||
})
|
||||
.get()
|
||||
})?;
|
||||
|
||||
// a = self; b = other; c = cond;
|
||||
//
|
||||
// r = c * a + (1 - c) * b
|
||||
// r = b + c * (a - b)
|
||||
// c * (a - b) = r - b
|
||||
let one = CS::one();
|
||||
cs.enforce(
|
||||
|| "conditionally_select",
|
||||
|_| cond.lc(one, ConstraintF::one()),
|
||||
|lc| lc + first.variable - second.variable,
|
||||
|lc| lc + result.variable - second.variable,
|
||||
);
|
||||
|
||||
Ok(result)
|
||||
cond_select_helper(
|
||||
cs,
|
||||
cond,
|
||||
(first.value, first.variable),
|
||||
(second.value, second.variable),
|
||||
)
|
||||
}
|
||||
|
||||
fn cost() -> usize {
|
||||
@@ -324,6 +328,40 @@ impl<ConstraintF: Field> CondSelectGadget<ConstraintF> for AllocatedBit {
|
||||
}
|
||||
}
|
||||
|
||||
fn cond_select_helper<F: Field, CS: ConstraintSystem<F>>(
|
||||
mut cs: CS,
|
||||
cond: &Boolean,
|
||||
first: (Option<bool>, impl Into<ConstraintVar<F>>),
|
||||
second: (Option<bool>, impl Into<ConstraintVar<F>>),
|
||||
) -> Result<AllocatedBit, SynthesisError> {
|
||||
let mut result_val = None;
|
||||
let result_var = cs.alloc(
|
||||
|| "cond_select_result",
|
||||
|| {
|
||||
result_val = cond.get_value().and_then(|c| if c { first.0 } else { second.0 });
|
||||
result_val.get().map(|v| F::from(v as u8))
|
||||
})?;
|
||||
|
||||
let first_var = first.1.into();
|
||||
let second_var = second.1.into();
|
||||
|
||||
// a = self; b = other; c = cond;
|
||||
//
|
||||
// r = c * a + (1 - c) * b
|
||||
// r = b + c * (a - b)
|
||||
// c * (a - b) = r - b
|
||||
let one = CS::one();
|
||||
cs.enforce(
|
||||
|| "conditionally_select",
|
||||
|_| cond.lc(one, F::one()),
|
||||
|lc| (&first_var - &second_var) + lc,
|
||||
|lc| ConstraintVar::from(result_var) - &second_var + lc,
|
||||
);
|
||||
|
||||
Ok(AllocatedBit { value: result_val, variable: result_var })
|
||||
|
||||
}
|
||||
|
||||
/// This is a boolean value which may be either a constant or
|
||||
/// an interpretation of an `AllocatedBit`.
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
@@ -747,6 +785,54 @@ impl<ConstraintF: Field> ToBytesGadget<ConstraintF> for Boolean {
|
||||
self.to_bytes(cs)
|
||||
}
|
||||
}
|
||||
impl<ConstraintF: Field> CondSelectGadget<ConstraintF> for Boolean {
|
||||
fn conditionally_select<CS>(
|
||||
mut cs: CS,
|
||||
cond: &Self,
|
||||
first: &Self,
|
||||
second: &Self,
|
||||
) -> Result<Self, SynthesisError>
|
||||
where
|
||||
CS: ConstraintSystem<ConstraintF>,
|
||||
{
|
||||
match cond {
|
||||
Boolean::Constant(true) => Ok(first.clone()),
|
||||
Boolean::Constant(false) => Ok(second.clone()),
|
||||
cond @ Boolean::Not(_) => Self::conditionally_select(cs, &cond.not(), second, first),
|
||||
cond @ Boolean::Is(_) => {
|
||||
match (first, second) {
|
||||
(x, &Boolean::Constant(false)) => {
|
||||
Boolean::and(cs.ns(|| "and"), cond, x).into()
|
||||
},
|
||||
(&Boolean::Constant(false), x) => {
|
||||
Boolean::and(cs.ns(|| "and"), &cond.not(), x)
|
||||
},
|
||||
(&Boolean::Constant(true), x) => {
|
||||
Boolean::or(cs.ns(|| "or"), cond, x).into()
|
||||
},
|
||||
(x, &Boolean::Constant(true)) => {
|
||||
Boolean::or(cs.ns(|| "or"), &cond.not(), x)
|
||||
},
|
||||
(a @ Boolean::Is(_), b @ Boolean::Is(_))
|
||||
| (a @ Boolean::Not(_), b @ Boolean::Not(_))
|
||||
| (a @ Boolean::Is(_), b @ Boolean::Not(_))
|
||||
| (a @ Boolean::Not(_), b @ Boolean::Is(_)) => {
|
||||
let a_lc = a.lc(CS::one(), ConstraintF::one());
|
||||
let b_lc = b.lc(CS::one(), ConstraintF::one());
|
||||
Ok(cond_select_helper(cs, cond, (a.get_value(), a_lc), (b.get_value(), b_lc))?.into())
|
||||
},
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn cost() -> usize {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
@@ -1288,6 +1374,71 @@ mod test {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_boolean_cond_select() {
|
||||
let variants = [
|
||||
OperandType::True,
|
||||
OperandType::False,
|
||||
OperandType::AllocatedTrue,
|
||||
OperandType::AllocatedFalse,
|
||||
OperandType::NegatedAllocatedTrue,
|
||||
OperandType::NegatedAllocatedFalse,
|
||||
];
|
||||
|
||||
for condition in variants.iter().cloned() {
|
||||
for first_operand in variants.iter().cloned() {
|
||||
for second_operand in variants.iter().cloned() {
|
||||
let mut cs = TestConstraintSystem::<Fr>::new();
|
||||
|
||||
let cond;
|
||||
let a;
|
||||
let b;
|
||||
|
||||
{
|
||||
let mut dyn_construct = |operand, name| {
|
||||
let cs = cs.ns(|| name);
|
||||
|
||||
match operand {
|
||||
OperandType::True => Boolean::constant(true),
|
||||
OperandType::False => Boolean::constant(false),
|
||||
OperandType::AllocatedTrue => {
|
||||
Boolean::from(AllocatedBit::alloc(cs, || Ok(true)).unwrap())
|
||||
},
|
||||
OperandType::AllocatedFalse => {
|
||||
Boolean::from(AllocatedBit::alloc(cs, || Ok(false)).unwrap())
|
||||
},
|
||||
OperandType::NegatedAllocatedTrue => {
|
||||
Boolean::from(AllocatedBit::alloc(cs, || Ok(true)).unwrap()).not()
|
||||
},
|
||||
OperandType::NegatedAllocatedFalse => {
|
||||
Boolean::from(AllocatedBit::alloc(cs, || Ok(false)).unwrap()).not()
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
cond = dyn_construct(condition, "cond");
|
||||
a = dyn_construct(first_operand, "a");
|
||||
b = dyn_construct(second_operand, "b");
|
||||
}
|
||||
|
||||
let before = cs.num_constraints();
|
||||
let c = Boolean::conditionally_select(&mut cs, &cond, &a, &b).unwrap();
|
||||
let after = cs.num_constraints();
|
||||
|
||||
assert!(
|
||||
cs.is_satisfied(),
|
||||
"failed with operands: cond: {:?}, a: {:?}, b: {:?}",
|
||||
condition,
|
||||
first_operand,
|
||||
second_operand,
|
||||
);
|
||||
assert_eq!(c.get_value(), if cond.get_value().unwrap() { a.get_value() } else { b.get_value() });
|
||||
assert!(<Boolean as CondSelectGadget<Fr>>::cost() >= after - before);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_boolean_or() {
|
||||
let variants = [
|
||||
|
||||
Reference in New Issue
Block a user