From 0fd45d3d83b56ddaaa1f35ba8262c2696512566f Mon Sep 17 00:00:00 2001 From: Pratyush Mishra Date: Tue, 8 Dec 2020 22:56:14 -0800 Subject: [PATCH] Reduce allocations in `UInt`s --- src/bits/boolean.rs | 7 ++-- src/bits/uint.rs | 100 ++++++++++++++++++++------------------------ src/bits/uint8.rs | 79 +++++++++++++++++----------------- 3 files changed, 88 insertions(+), 98 deletions(-) diff --git a/src/bits/boolean.rs b/src/bits/boolean.rs index 244ec85..294ef62 100644 --- a/src/bits/boolean.rs +++ b/src/bits/boolean.rs @@ -829,11 +829,10 @@ impl ToBytesGadget for Boolean { /// Outputs `1u8` if `self` is true, and `0u8` otherwise. #[tracing::instrument(target = "r1cs")] fn to_bytes(&self) -> Result>, SynthesisError> { - let mut bits = vec![self.clone()]; - bits.extend(vec![Boolean::constant(false); 7]); let value = self.value().map(u8::from).ok(); - let byte = UInt8 { bits, value }; - Ok(vec![byte]) + let mut bits = [Boolean::FALSE; 8]; + bits[0] = self.clone(); + Ok(vec![UInt8 { bits, value }]) } } diff --git a/src/bits/uint.rs b/src/bits/uint.rs index 40ff267..880238e 100644 --- a/src/bits/uint.rs +++ b/src/bits/uint.rs @@ -33,7 +33,7 @@ macro_rules! make_uint { #[derive(Clone, Debug)] pub struct $name { // Least significant bit first - bits: Vec>, + bits: [Boolean; $size], value: Option<$native>, } @@ -41,7 +41,7 @@ macro_rules! make_uint { type Value = $native; fn cs(&self) -> ConstraintSystemRef { - self.bits.as_slice().cs() + self.bits.as_ref().cs() } fn value(&self) -> Result { @@ -65,16 +65,11 @@ macro_rules! make_uint { #[doc = $native_doc_name] #[doc = "` type."] pub fn constant(value: $native) -> Self { - let mut bits = Vec::with_capacity($size); + let mut bits = [Boolean::FALSE; $size]; let mut tmp = value; - for _ in 0..$size { - if tmp & 1 == 1 { - bits.push(Boolean::constant(true)) - } else { - bits.push(Boolean::constant(false)) - } - + for i in 0..$size { + bits[i] = Boolean::constant((tmp & 1) == 1); tmp >>= 1; } @@ -86,7 +81,7 @@ macro_rules! make_uint { /// Turns `self` into the underlying little-endian bits. pub fn to_bits_le(&self) -> Vec> { - self.bits.clone() + self.bits.to_vec() } /// Construct `Self` from a slice of `Boolean`s. @@ -99,7 +94,7 @@ macro_rules! make_uint { pub fn from_bits_le(bits: &[Boolean]) -> Self { assert_eq!(bits.len(), $size); - let bits = bits.to_vec(); + let bits = <&[Boolean; $size]>::try_from(bits).unwrap().clone(); let mut value = Some(0); for b in bits.iter().rev() { @@ -130,23 +125,22 @@ macro_rules! make_uint { /// Rotates `self` to the right by `by` steps, wrapping around. #[tracing::instrument(target = "r1cs", skip(self))] pub fn rotr(&self, by: usize) -> Self { + let mut result = self.clone(); let by = by % $size; let new_bits = self .bits .iter() .skip(by) - .chain(self.bits.iter()) - .take($size) - .cloned() - .collect(); + .chain(&self.bits) + .take($size); - $name { - bits: new_bits, - value: self - .value - .map(|v| v.rotate_right(u32::try_from(by).unwrap())), + for (res, new) in result.bits.iter_mut().zip(new_bits) { + *res = new.clone(); } + + result.value = self.value.map(|v| v.rotate_right(u32::try_from(by).unwrap())); + result } /// Outputs `self ^ other`. @@ -155,22 +149,19 @@ macro_rules! make_uint { /// *does not* create any constraints or variables. #[tracing::instrument(target = "r1cs", skip(self, other))] pub fn xor(&self, other: &Self) -> Result { - let new_value = match (self.value, other.value) { + let mut result = self.clone(); + result.value = match (self.value, other.value) { (Some(a), Some(b)) => Some(a ^ b), _ => None, }; - let bits = self - .bits - .iter() - .zip(other.bits.iter()) - .map(|(a, b)| a.xor(b)) - .collect::>()?; + let new_bits = self.bits.iter().zip(&other.bits).map(|(a, b)| a.xor(b)); - Ok($name { - bits, - value: new_value, - }) + for (res, new) in result.bits.iter_mut().zip(new_bits) { + *res = new?; + } + + Ok(result) } /// Perform modular addition of `operands`. @@ -292,9 +283,10 @@ macro_rules! make_uint { // Discard carry bits that we don't care about result_bits.truncate($size); + let bits = TryFrom::try_from(result_bits).unwrap(); Ok($name { - bits: result_bits, + bits, value: modular_value, }) } @@ -314,7 +306,7 @@ macro_rules! make_uint { impl EqGadget for $name { #[tracing::instrument(target = "r1cs", skip(self))] fn is_eq(&self, other: &Self) -> Result, SynthesisError> { - self.bits.as_slice().is_eq(&other.bits) + self.bits.as_ref().is_eq(&other.bits) } #[tracing::instrument(target = "r1cs", skip(self))] @@ -348,19 +340,20 @@ macro_rules! make_uint { .bits .iter() .zip(&false_value.bits) - .map(|(t, f)| cond.select(t, f)) - .collect::, SynthesisError>>()?; - let selected_value = cond.value().ok().and_then(|cond| { + .map(|(t, f)| cond.select(t, f)); + let mut bits = [Boolean::FALSE; $size]; + for (result, new) in bits.iter_mut().zip(selected_bits) { + *result = new?; + } + + let value = cond.value().ok().and_then(|cond| { if cond { true_value.value().ok() } else { false_value.value().ok() } }); - Ok(Self { - bits: selected_bits, - value: selected_value, - }) + Ok(Self { bits, value }) } } @@ -372,19 +365,18 @@ macro_rules! make_uint { ) -> Result { let ns = cs.into(); let cs = ns.cs(); - let value = f().map(|f| *f.borrow()); - let values = match value { - Ok(val) => (0..$size).map(|i| Some((val >> i) & 1 == 1)).collect(), - _ => vec![None; $size], - }; - let bits = values - .into_iter() - .map(|v| Boolean::new_variable(cs.clone(), || v.get(), mode)) - .collect::, _>>()?; - Ok(Self { - bits, - value: value.ok(), - }) + let value = f().map(|f| *f.borrow()).ok(); + + let mut values = [None; $size]; + if let Some(val) = value { + values.iter_mut().enumerate().for_each(|(i, v)| *v = Some((val >> i) & 1 == 1)); + } + + let mut bits = [Boolean::FALSE; $size]; + for (b, v) in bits.iter_mut().zip(&values) { + *b = Boolean::new_variable(cs.clone(), || v.get(), mode)?; + } + Ok(Self { bits, value }) } } diff --git a/src/bits/uint8.rs b/src/bits/uint8.rs index 1ccd7f1..241244a 100644 --- a/src/bits/uint8.rs +++ b/src/bits/uint8.rs @@ -3,14 +3,14 @@ use ark_ff::{Field, FpParameters, PrimeField, ToConstraintField}; use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; use crate::{fields::fp::AllocatedFp, prelude::*, Assignment, Vec}; -use core::borrow::Borrow; +use core::{borrow::Borrow, convert::TryFrom}; /// Represents an interpretation of 8 `Boolean` objects as an /// unsigned integer. #[derive(Clone, Debug)] pub struct UInt8 { /// Little-endian representation: least significant bit first - pub(crate) bits: Vec>, + pub(crate) bits: [Boolean; 8], pub(crate) value: Option, } @@ -18,7 +18,7 @@ impl R1CSVar for UInt8 { type Value = u8; fn cs(&self) -> ConstraintSystemRef { - self.bits.as_slice().cs() + self.bits.as_ref().cs() } fn value(&self) -> Result { @@ -84,12 +84,12 @@ impl UInt8 { /// # } /// ``` pub fn constant(value: u8) -> Self { - let mut bits = Vec::with_capacity(8); + let mut bits = [Boolean::FALSE; 8]; let mut tmp = value; - for _ in 0..8 { + for i in 0..8 { // If last bit is one, push one. - bits.push(Boolean::constant(tmp & 1 == 1)); + bits[i] = Boolean::constant((tmp & 1) == 1); tmp >>= 1; } @@ -201,8 +201,7 @@ impl UInt8 { #[tracing::instrument(target = "r1cs")] pub fn from_bits_le(bits: &[Boolean]) -> Self { assert_eq!(bits.len(), 8); - - let bits = bits.to_vec(); + let bits = <&[Boolean; 8]>::try_from(bits).unwrap().clone(); let mut value = Some(0u8); for (i, b) in bits.iter().enumerate() { @@ -239,29 +238,26 @@ impl UInt8 { /// ``` #[tracing::instrument(target = "r1cs")] pub fn xor(&self, other: &Self) -> Result { - let new_value = match (self.value, other.value) { + let mut result = self.clone(); + result.value = match (self.value, other.value) { (Some(a), Some(b)) => Some(a ^ b), _ => None, }; - let bits = self - .bits - .iter() - .zip(other.bits.iter()) - .map(|(a, b)| a.xor(b)) - .collect::>()?; + let new_bits = self.bits.iter().zip(&other.bits).map(|(a, b)| a.xor(b)); - Ok(Self { - bits, - value: new_value, - }) + for (res, new) in result.bits.iter_mut().zip(new_bits) { + *res = new?; + } + + Ok(result) } } impl EqGadget for UInt8 { #[tracing::instrument(target = "r1cs")] fn is_eq(&self, other: &Self) -> Result, SynthesisError> { - self.bits.as_slice().is_eq(&other.bits) + self.bits.as_ref().is_eq(&other.bits) } #[tracing::instrument(target = "r1cs")] @@ -295,19 +291,20 @@ impl CondSelectGadget for UInt8 { .bits .iter() .zip(&false_value.bits) - .map(|(t, f)| cond.select(t, f)) - .collect::, SynthesisError>>()?; - let selected_value = cond.value().ok().and_then(|cond| { + .map(|(t, f)| cond.select(t, f)); + let mut bits = [Boolean::FALSE; 8]; + for (result, new) in bits.iter_mut().zip(selected_bits) { + *result = new?; + } + + let value = cond.value().ok().and_then(|cond| { if cond { true_value.value().ok() } else { false_value.value().ok() } }); - Ok(Self { - bits: selected_bits, - value: selected_value, - }) + Ok(Self { bits, value }) } } @@ -319,19 +316,21 @@ impl AllocVar for UInt8 { ) -> Result { let ns = cs.into(); let cs = ns.cs(); - let value = f().map(|f| *f.borrow()); - let values = match value { - Ok(val) => (0..8).map(|i| Some((val >> i) & 1 == 1)).collect(), - _ => vec![None; 8], - }; - let bits = values - .into_iter() - .map(|v| Boolean::new_variable(cs.clone(), || v.get(), mode)) - .collect::, _>>()?; - Ok(Self { - bits, - value: value.ok(), - }) + let value = f().map(|f| *f.borrow()).ok(); + + let mut values = [None; 8]; + if let Some(val) = value { + values + .iter_mut() + .enumerate() + .for_each(|(i, v)| *v = Some((val >> i) & 1 == 1)); + } + + let mut bits = [Boolean::FALSE; 8]; + for (b, v) in bits.iter_mut().zip(&values) { + *b = Boolean::new_variable(cs.clone(), || v.get(), mode)?; + } + Ok(Self { bits, value }) } }