used AllocGadget for UInt64

This commit is contained in:
weikeng
2020-03-26 13:22:26 -07:00
committed by Pratyush Mishra
parent 0ffa409ec1
commit a205f191f7

View File

@@ -7,6 +7,7 @@ use crate::{
prelude::*, prelude::*,
Assignment, Vec, Assignment, Vec,
}; };
use core::borrow::Borrow;
/// Represents an interpretation of 64 `Boolean` objects as an /// Represents an interpretation of 64 `Boolean` objects as an
/// unsigned integer. /// unsigned integer.
@@ -40,7 +41,7 @@ impl UInt64 {
} }
/// Allocate a `UInt64` in the constraint system /// Allocate a `UInt64` in the constraint system
pub fn alloc<ConstraintF, CS>(mut cs: CS, value: Option<u64>) -> Result<Self, SynthesisError> pub fn _alloc<ConstraintF, CS>(mut cs: CS, value: Option<u64>) -> Result<Self, SynthesisError>
where where
ConstraintF: Field, ConstraintF: Field,
CS: ConstraintSystem<ConstraintF>, CS: ConstraintSystem<ConstraintF>,
@@ -55,7 +56,7 @@ impl UInt64 {
} }
v v
}, }
None => vec![None; 64], None => vec![None; 64],
}; };
@@ -94,19 +95,19 @@ impl UInt64 {
if b { if b {
value.as_mut().map(|v| *v |= 1); value.as_mut().map(|v| *v |= 1);
} }
}, }
&Boolean::Is(ref b) => match b.get_value() { &Boolean::Is(ref b) => match b.get_value() {
Some(true) => { Some(true) => {
value.as_mut().map(|v| *v |= 1); value.as_mut().map(|v| *v |= 1);
}, }
Some(false) => {}, Some(false) => {}
None => value = None, None => value = None,
}, },
&Boolean::Not(ref b) => match b.get_value() { &Boolean::Not(ref b) => match b.get_value() {
Some(false) => { Some(false) => {
value.as_mut().map(|v| *v |= 1); value.as_mut().map(|v| *v |= 1);
}, }
Some(true) => {}, Some(true) => {}
None => value = None, None => value = None,
}, },
} }
@@ -193,12 +194,12 @@ impl UInt64 {
match op.value { match op.value {
Some(val) => { Some(val) => {
result_value.as_mut().map(|v| *v += u128::from(val)); result_value.as_mut().map(|v| *v += u128::from(val));
}, }
None => { None => {
// If any of our operands have unknown value, we won't // If any of our operands have unknown value, we won't
// know the value of the result // know the value of the result
result_value = None; result_value = None;
}, }
} }
// Iterate over each bit_gadget of the operand and add the operand to // Iterate over each bit_gadget of the operand and add the operand to
@@ -211,18 +212,18 @@ impl UInt64 {
// Add coeff * bit_gadget // Add coeff * bit_gadget
lc += (coeff, bit.get_variable()); lc += (coeff, bit.get_variable());
}, }
Boolean::Not(ref bit) => { Boolean::Not(ref bit) => {
all_constants = false; all_constants = false;
// Add coeff * (1 - bit_gadget) = coeff * ONE - coeff * bit_gadget // Add coeff * (1 - bit_gadget) = coeff * ONE - coeff * bit_gadget
lc = lc + (coeff, CS::one()) - (coeff, bit.get_variable()); lc = lc + (coeff, CS::one()) - (coeff, bit.get_variable());
}, }
Boolean::Constant(bit) => { Boolean::Constant(bit) => {
if bit { if bit {
lc += (coeff, CS::one()); lc += (coeff, CS::one());
} }
}, }
} }
coeff.double_in_place(); coeff.double_in_place();
@@ -275,6 +276,33 @@ impl UInt64 {
} }
} }
impl<ConstraintF: Field> AllocGadget<u64, ConstraintF> for UInt64 {
fn alloc<F, T, CS: ConstraintSystem<ConstraintF>>(
mut cs: CS,
value_gen: F,
) -> Result<Self, SynthesisError>
where
F: FnOnce() -> Result<T, SynthesisError>,
T: Borrow<u64>,
{
let val = value_gen()?.borrow().clone();
Self::_alloc(&mut cs.ns(|| "alloc u64"), Some(val))
}
fn alloc_input<F, T, CS: ConstraintSystem<ConstraintF>>(
mut cs: CS,
value_gen: F,
) -> Result<Self, SynthesisError>
where
F: FnOnce() -> Result<T, SynthesisError>,
T: Borrow<u64>,
{
let val = value_gen()?.borrow().clone();
Self::_alloc(&mut cs.ns(|| "alloc u64"), Some(val))
}
}
impl<ConstraintF: Field> ToBytesGadget<ConstraintF> for UInt64 { impl<ConstraintF: Field> ToBytesGadget<ConstraintF> for UInt64 {
#[inline] #[inline]
fn to_bytes<CS: ConstraintSystem<ConstraintF>>( fn to_bytes<CS: ConstraintSystem<ConstraintF>>(
@@ -310,6 +338,13 @@ impl<ConstraintF: Field> ToBytesGadget<ConstraintF> for UInt64 {
Ok(bytes) Ok(bytes)
} }
fn to_bytes_strict<CS: ConstraintSystem<ConstraintF>>(
&self,
cs: CS,
) -> Result<Vec<UInt8>, SynthesisError> {
self.to_bytes(cs)
}
} }
impl PartialEq for UInt64 { impl PartialEq for UInt64 {
@@ -345,7 +380,10 @@ impl<ConstraintF: Field> ConditionalEqGadget<ConstraintF> for UInt64 {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use super::UInt64; use super::UInt64;
use crate::{bits::boolean::Boolean, test_constraint_system::TestConstraintSystem, Vec}; use crate::{
alloc::AllocGadget, bits::boolean::Boolean, test_constraint_system::TestConstraintSystem,
Vec,
};
use algebra::{bls12_381::Fr, One, Zero}; use algebra::{bls12_381::Fr, One, Zero};
use r1cs_core::ConstraintSystem; use r1cs_core::ConstraintSystem;
use rand::{Rng, SeedableRng}; use rand::{Rng, SeedableRng};
@@ -366,7 +404,7 @@ mod test {
match bit_gadget { match bit_gadget {
&Boolean::Constant(bit_gadget) => { &Boolean::Constant(bit_gadget) => {
assert!(bit_gadget == ((b.value.unwrap() >> i) & 1 == 1)); assert!(bit_gadget == ((b.value.unwrap() >> i) & 1 == 1));
}, }
_ => unreachable!(), _ => unreachable!(),
} }
} }
@@ -375,8 +413,8 @@ mod test {
for x in v.iter().zip(expected_to_be_same.iter()) { for x in v.iter().zip(expected_to_be_same.iter()) {
match x { match x {
(&Boolean::Constant(true), &Boolean::Constant(true)) => {}, (&Boolean::Constant(true), &Boolean::Constant(true)) => {}
(&Boolean::Constant(false), &Boolean::Constant(false)) => {}, (&Boolean::Constant(false), &Boolean::Constant(false)) => {}
_ => unreachable!(), _ => unreachable!(),
} }
} }
@@ -396,9 +434,9 @@ mod test {
let mut expected = a ^ b ^ c; let mut expected = a ^ b ^ c;
let a_bit = UInt64::alloc(cs.ns(|| "a_bit"), Some(a)).unwrap(); let a_bit = UInt64::alloc(cs.ns(|| "a_bit"), || Ok(a)).unwrap();
let b_bit = UInt64::constant(b); let b_bit = UInt64::constant(b);
let c_bit = UInt64::alloc(cs.ns(|| "c_bit"), Some(c)).unwrap(); let c_bit = UInt64::alloc(cs.ns(|| "c_bit"), || Ok(c)).unwrap();
let r = a_bit.xor(cs.ns(|| "first xor"), &b_bit).unwrap(); let r = a_bit.xor(cs.ns(|| "first xor"), &b_bit).unwrap();
let r = r.xor(cs.ns(|| "second xor"), &c_bit).unwrap(); let r = r.xor(cs.ns(|| "second xor"), &c_bit).unwrap();
@@ -411,13 +449,13 @@ mod test {
match b { match b {
&Boolean::Is(ref b) => { &Boolean::Is(ref b) => {
assert!(b.get_value().unwrap() == (expected & 1 == 1)); assert!(b.get_value().unwrap() == (expected & 1 == 1));
}, }
&Boolean::Not(ref b) => { &Boolean::Not(ref b) => {
assert!(!b.get_value().unwrap() == (expected & 1 == 1)); assert!(!b.get_value().unwrap() == (expected & 1 == 1));
}, }
&Boolean::Constant(b) => { &Boolean::Constant(b) => {
assert!(b == (expected & 1 == 1)); assert!(b == (expected & 1 == 1));
}, }
} }
expected >>= 1; expected >>= 1;
@@ -452,7 +490,7 @@ mod test {
&Boolean::Not(_) => panic!(), &Boolean::Not(_) => panic!(),
&Boolean::Constant(b) => { &Boolean::Constant(b) => {
assert!(b == (expected & 1 == 1)); assert!(b == (expected & 1 == 1));
}, }
} }
expected >>= 1; expected >>= 1;
@@ -474,10 +512,10 @@ mod test {
let mut expected = (a ^ b).wrapping_add(c).wrapping_add(d); let mut expected = (a ^ b).wrapping_add(c).wrapping_add(d);
let a_bit = UInt64::alloc(cs.ns(|| "a_bit"), Some(a)).unwrap(); let a_bit = UInt64::alloc(cs.ns(|| "a_bit"), || Ok(a)).unwrap();
let b_bit = UInt64::constant(b); let b_bit = UInt64::constant(b);
let c_bit = UInt64::constant(c); let c_bit = UInt64::constant(c);
let d_bit = UInt64::alloc(cs.ns(|| "d_bit"), Some(d)).unwrap(); let d_bit = UInt64::alloc(cs.ns(|| "d_bit"), || Ok(d)).unwrap();
let r = a_bit.xor(cs.ns(|| "xor"), &b_bit).unwrap(); let r = a_bit.xor(cs.ns(|| "xor"), &b_bit).unwrap();
let r = UInt64::addmany(cs.ns(|| "addition"), &[r, c_bit, d_bit]).unwrap(); let r = UInt64::addmany(cs.ns(|| "addition"), &[r, c_bit, d_bit]).unwrap();
@@ -490,10 +528,10 @@ mod test {
match b { match b {
&Boolean::Is(ref b) => { &Boolean::Is(ref b) => {
assert!(b.get_value().unwrap() == (expected & 1 == 1)); assert!(b.get_value().unwrap() == (expected & 1 == 1));
}, }
&Boolean::Not(ref b) => { &Boolean::Not(ref b) => {
assert!(!b.get_value().unwrap() == (expected & 1 == 1)); assert!(!b.get_value().unwrap() == (expected & 1 == 1));
}, }
&Boolean::Constant(_) => unreachable!(), &Boolean::Constant(_) => unreachable!(),
} }
@@ -529,7 +567,7 @@ mod test {
match b { match b {
&Boolean::Constant(b) => { &Boolean::Constant(b) => {
assert_eq!(b, tmp & 1 == 1); assert_eq!(b, tmp & 1 == 1);
}, }
_ => unreachable!(), _ => unreachable!(),
} }