From a205f191f73c3101e9ab49028d159fd8e41ec461 Mon Sep 17 00:00:00 2001 From: weikeng Date: Thu, 26 Mar 2020 13:22:26 -0700 Subject: [PATCH] used AllocGadget for UInt64 --- r1cs-std/src/bits/uint64.rs | 100 +++++++++++++++++++++++++----------- 1 file changed, 69 insertions(+), 31 deletions(-) diff --git a/r1cs-std/src/bits/uint64.rs b/r1cs-std/src/bits/uint64.rs index 50a168d..83bd286 100644 --- a/r1cs-std/src/bits/uint64.rs +++ b/r1cs-std/src/bits/uint64.rs @@ -7,13 +7,14 @@ use crate::{ prelude::*, Assignment, Vec, }; +use core::borrow::Borrow; /// Represents an interpretation of 64 `Boolean` objects as an /// unsigned integer. #[derive(Clone, Debug)] pub struct UInt64 { // Least significant bit_gadget first - bits: Vec, + bits: Vec, value: Option, } @@ -40,7 +41,7 @@ impl UInt64 { } /// Allocate a `UInt64` in the constraint system - pub fn alloc(mut cs: CS, value: Option) -> Result + pub fn _alloc(mut cs: CS, value: Option) -> Result where ConstraintF: Field, CS: ConstraintSystem, @@ -55,7 +56,7 @@ impl UInt64 { } v - }, + } None => vec![None; 64], }; @@ -94,19 +95,19 @@ impl UInt64 { if b { value.as_mut().map(|v| *v |= 1); } - }, + } &Boolean::Is(ref b) => match b.get_value() { Some(true) => { value.as_mut().map(|v| *v |= 1); - }, - Some(false) => {}, + } + Some(false) => {} None => value = None, }, &Boolean::Not(ref b) => match b.get_value() { Some(false) => { value.as_mut().map(|v| *v |= 1); - }, - Some(true) => {}, + } + Some(true) => {} None => value = None, }, } @@ -128,7 +129,7 @@ impl UInt64 { .collect(); UInt64 { - bits: new_bits, + bits: new_bits, value: self.value.map(|v| v.rotate_right(by as u32)), } } @@ -193,12 +194,12 @@ impl UInt64 { match op.value { Some(val) => { result_value.as_mut().map(|v| *v += u128::from(val)); - }, + } None => { // If any of our operands have unknown value, we won't // know the value of the result result_value = None; - }, + } } // Iterate over each bit_gadget of the operand and add the operand to @@ -211,18 +212,18 @@ impl UInt64 { // Add coeff * bit_gadget lc += (coeff, bit.get_variable()); - }, + } Boolean::Not(ref bit) => { all_constants = false; // Add coeff * (1 - bit_gadget) = coeff * ONE - coeff * bit_gadget lc = lc + (coeff, CS::one()) - (coeff, bit.get_variable()); - }, + } Boolean::Constant(bit) => { if bit { lc += (coeff, CS::one()); } - }, + } } coeff.double_in_place(); @@ -269,12 +270,39 @@ impl UInt64 { result_bits.truncate(64); Ok(UInt64 { - bits: result_bits, + bits: result_bits, value: modular_value, }) } } +impl AllocGadget for UInt64 { + fn alloc>( + mut cs: CS, + value_gen: F, + ) -> Result + where + F: FnOnce() -> Result, + T: Borrow, + { + let val = value_gen()?.borrow().clone(); + + Self::_alloc(&mut cs.ns(|| "alloc u64"), Some(val)) + } + + fn alloc_input>( + mut cs: CS, + value_gen: F, + ) -> Result + where + F: FnOnce() -> Result, + T: Borrow, + { + let val = value_gen()?.borrow().clone(); + Self::_alloc(&mut cs.ns(|| "alloc u64"), Some(val)) + } +} + impl ToBytesGadget for UInt64 { #[inline] fn to_bytes>( @@ -302,7 +330,7 @@ impl ToBytesGadget for UInt64 { let mut bytes = Vec::new(); for (i, chunk8) in self.to_bits_le().chunks(8).enumerate() { let byte = UInt8 { - bits: chunk8.to_vec(), + bits: chunk8.to_vec(), value: value_chunks[i], }; bytes.push(byte); @@ -310,6 +338,13 @@ impl ToBytesGadget for UInt64 { Ok(bytes) } + + fn to_bytes_strict>( + &self, + cs: CS, + ) -> Result, SynthesisError> { + self.to_bytes(cs) + } } impl PartialEq for UInt64 { @@ -345,7 +380,10 @@ impl ConditionalEqGadget for UInt64 { #[cfg(test)] mod test { use super::UInt64; - use crate::{bits::boolean::Boolean, test_constraint_system::TestConstraintSystem, Vec}; + use crate::{ + alloc::AllocGadget, bits::boolean::Boolean, test_constraint_system::TestConstraintSystem, + Vec, + }; use algebra::{bls12_381::Fr, One, Zero}; use r1cs_core::ConstraintSystem; use rand::{Rng, SeedableRng}; @@ -366,7 +404,7 @@ mod test { match bit_gadget { &Boolean::Constant(bit_gadget) => { assert!(bit_gadget == ((b.value.unwrap() >> i) & 1 == 1)); - }, + } _ => unreachable!(), } } @@ -375,8 +413,8 @@ mod test { for x in v.iter().zip(expected_to_be_same.iter()) { match x { - (&Boolean::Constant(true), &Boolean::Constant(true)) => {}, - (&Boolean::Constant(false), &Boolean::Constant(false)) => {}, + (&Boolean::Constant(true), &Boolean::Constant(true)) => {} + (&Boolean::Constant(false), &Boolean::Constant(false)) => {} _ => unreachable!(), } } @@ -396,9 +434,9 @@ mod test { let mut expected = a ^ b ^ c; - let a_bit = UInt64::alloc(cs.ns(|| "a_bit"), Some(a)).unwrap(); + let a_bit = UInt64::alloc(cs.ns(|| "a_bit"), || Ok(a)).unwrap(); let b_bit = UInt64::constant(b); - let c_bit = UInt64::alloc(cs.ns(|| "c_bit"), Some(c)).unwrap(); + let c_bit = UInt64::alloc(cs.ns(|| "c_bit"), || Ok(c)).unwrap(); let r = a_bit.xor(cs.ns(|| "first xor"), &b_bit).unwrap(); let r = r.xor(cs.ns(|| "second xor"), &c_bit).unwrap(); @@ -411,13 +449,13 @@ mod test { match b { &Boolean::Is(ref b) => { assert!(b.get_value().unwrap() == (expected & 1 == 1)); - }, + } &Boolean::Not(ref b) => { assert!(!b.get_value().unwrap() == (expected & 1 == 1)); - }, + } &Boolean::Constant(b) => { assert!(b == (expected & 1 == 1)); - }, + } } expected >>= 1; @@ -452,7 +490,7 @@ mod test { &Boolean::Not(_) => panic!(), &Boolean::Constant(b) => { assert!(b == (expected & 1 == 1)); - }, + } } expected >>= 1; @@ -474,10 +512,10 @@ mod test { let mut expected = (a ^ b).wrapping_add(c).wrapping_add(d); - let a_bit = UInt64::alloc(cs.ns(|| "a_bit"), Some(a)).unwrap(); + let a_bit = UInt64::alloc(cs.ns(|| "a_bit"), || Ok(a)).unwrap(); let b_bit = UInt64::constant(b); let c_bit = UInt64::constant(c); - let d_bit = UInt64::alloc(cs.ns(|| "d_bit"), Some(d)).unwrap(); + let d_bit = UInt64::alloc(cs.ns(|| "d_bit"), || Ok(d)).unwrap(); let r = a_bit.xor(cs.ns(|| "xor"), &b_bit).unwrap(); let r = UInt64::addmany(cs.ns(|| "addition"), &[r, c_bit, d_bit]).unwrap(); @@ -490,10 +528,10 @@ mod test { match b { &Boolean::Is(ref b) => { assert!(b.get_value().unwrap() == (expected & 1 == 1)); - }, + } &Boolean::Not(ref b) => { assert!(!b.get_value().unwrap() == (expected & 1 == 1)); - }, + } &Boolean::Constant(_) => unreachable!(), } @@ -529,7 +567,7 @@ mod test { match b { &Boolean::Constant(b) => { assert_eq!(b, tmp & 1 == 1); - }, + } _ => unreachable!(), }