Reduce allocations in UInts

This commit is contained in:
Pratyush Mishra
2020-12-08 22:56:14 -08:00
parent 905e7284b2
commit 0fd45d3d83
3 changed files with 88 additions and 98 deletions

View File

@@ -829,11 +829,10 @@ impl<F: Field> ToBytesGadget<F> for Boolean<F> {
/// Outputs `1u8` if `self` is true, and `0u8` otherwise. /// Outputs `1u8` if `self` is true, and `0u8` otherwise.
#[tracing::instrument(target = "r1cs")] #[tracing::instrument(target = "r1cs")]
fn to_bytes(&self) -> Result<Vec<UInt8<F>>, SynthesisError> { fn to_bytes(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
let mut bits = vec![self.clone()];
bits.extend(vec![Boolean::constant(false); 7]);
let value = self.value().map(u8::from).ok(); let value = self.value().map(u8::from).ok();
let byte = UInt8 { bits, value }; let mut bits = [Boolean::FALSE; 8];
Ok(vec![byte]) bits[0] = self.clone();
Ok(vec![UInt8 { bits, value }])
} }
} }

View File

@@ -33,7 +33,7 @@ macro_rules! make_uint {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct $name<F: Field> { pub struct $name<F: Field> {
// Least significant bit first // Least significant bit first
bits: Vec<Boolean<F>>, bits: [Boolean<F>; $size],
value: Option<$native>, value: Option<$native>,
} }
@@ -41,7 +41,7 @@ macro_rules! make_uint {
type Value = $native; type Value = $native;
fn cs(&self) -> ConstraintSystemRef<F> { fn cs(&self) -> ConstraintSystemRef<F> {
self.bits.as_slice().cs() self.bits.as_ref().cs()
} }
fn value(&self) -> Result<Self::Value, SynthesisError> { fn value(&self) -> Result<Self::Value, SynthesisError> {
@@ -65,16 +65,11 @@ macro_rules! make_uint {
#[doc = $native_doc_name] #[doc = $native_doc_name]
#[doc = "` type."] #[doc = "` type."]
pub fn constant(value: $native) -> Self { pub fn constant(value: $native) -> Self {
let mut bits = Vec::with_capacity($size); let mut bits = [Boolean::FALSE; $size];
let mut tmp = value; let mut tmp = value;
for _ in 0..$size { for i in 0..$size {
if tmp & 1 == 1 { bits[i] = Boolean::constant((tmp & 1) == 1);
bits.push(Boolean::constant(true))
} else {
bits.push(Boolean::constant(false))
}
tmp >>= 1; tmp >>= 1;
} }
@@ -86,7 +81,7 @@ macro_rules! make_uint {
/// Turns `self` into the underlying little-endian bits. /// Turns `self` into the underlying little-endian bits.
pub fn to_bits_le(&self) -> Vec<Boolean<F>> { pub fn to_bits_le(&self) -> Vec<Boolean<F>> {
self.bits.clone() self.bits.to_vec()
} }
/// Construct `Self` from a slice of `Boolean`s. /// Construct `Self` from a slice of `Boolean`s.
@@ -99,7 +94,7 @@ macro_rules! make_uint {
pub fn from_bits_le(bits: &[Boolean<F>]) -> Self { pub fn from_bits_le(bits: &[Boolean<F>]) -> Self {
assert_eq!(bits.len(), $size); assert_eq!(bits.len(), $size);
let bits = bits.to_vec(); let bits = <&[Boolean<F>; $size]>::try_from(bits).unwrap().clone();
let mut value = Some(0); let mut value = Some(0);
for b in bits.iter().rev() { for b in bits.iter().rev() {
@@ -130,23 +125,22 @@ macro_rules! make_uint {
/// Rotates `self` to the right by `by` steps, wrapping around. /// Rotates `self` to the right by `by` steps, wrapping around.
#[tracing::instrument(target = "r1cs", skip(self))] #[tracing::instrument(target = "r1cs", skip(self))]
pub fn rotr(&self, by: usize) -> Self { pub fn rotr(&self, by: usize) -> Self {
let mut result = self.clone();
let by = by % $size; let by = by % $size;
let new_bits = self let new_bits = self
.bits .bits
.iter() .iter()
.skip(by) .skip(by)
.chain(self.bits.iter()) .chain(&self.bits)
.take($size) .take($size);
.cloned()
.collect();
$name { for (res, new) in result.bits.iter_mut().zip(new_bits) {
bits: new_bits, *res = new.clone();
value: self
.value
.map(|v| v.rotate_right(u32::try_from(by).unwrap())),
} }
result.value = self.value.map(|v| v.rotate_right(u32::try_from(by).unwrap()));
result
} }
/// Outputs `self ^ other`. /// Outputs `self ^ other`.
@@ -155,22 +149,19 @@ macro_rules! make_uint {
/// *does not* create any constraints or variables. /// *does not* create any constraints or variables.
#[tracing::instrument(target = "r1cs", skip(self, other))] #[tracing::instrument(target = "r1cs", skip(self, other))]
pub fn xor(&self, other: &Self) -> Result<Self, SynthesisError> { pub fn xor(&self, other: &Self) -> Result<Self, SynthesisError> {
let new_value = match (self.value, other.value) { let mut result = self.clone();
result.value = match (self.value, other.value) {
(Some(a), Some(b)) => Some(a ^ b), (Some(a), Some(b)) => Some(a ^ b),
_ => None, _ => None,
}; };
let bits = self let new_bits = self.bits.iter().zip(&other.bits).map(|(a, b)| a.xor(b));
.bits
.iter()
.zip(other.bits.iter())
.map(|(a, b)| a.xor(b))
.collect::<Result<_, _>>()?;
Ok($name { for (res, new) in result.bits.iter_mut().zip(new_bits) {
bits, *res = new?;
value: new_value, }
})
Ok(result)
} }
/// Perform modular addition of `operands`. /// Perform modular addition of `operands`.
@@ -292,9 +283,10 @@ macro_rules! make_uint {
// Discard carry bits that we don't care about // Discard carry bits that we don't care about
result_bits.truncate($size); result_bits.truncate($size);
let bits = TryFrom::try_from(result_bits).unwrap();
Ok($name { Ok($name {
bits: result_bits, bits,
value: modular_value, value: modular_value,
}) })
} }
@@ -314,7 +306,7 @@ macro_rules! make_uint {
impl<ConstraintF: Field> EqGadget<ConstraintF> for $name<ConstraintF> { impl<ConstraintF: Field> EqGadget<ConstraintF> for $name<ConstraintF> {
#[tracing::instrument(target = "r1cs", skip(self))] #[tracing::instrument(target = "r1cs", skip(self))]
fn is_eq(&self, other: &Self) -> Result<Boolean<ConstraintF>, SynthesisError> { fn is_eq(&self, other: &Self) -> Result<Boolean<ConstraintF>, SynthesisError> {
self.bits.as_slice().is_eq(&other.bits) self.bits.as_ref().is_eq(&other.bits)
} }
#[tracing::instrument(target = "r1cs", skip(self))] #[tracing::instrument(target = "r1cs", skip(self))]
@@ -348,19 +340,20 @@ macro_rules! make_uint {
.bits .bits
.iter() .iter()
.zip(&false_value.bits) .zip(&false_value.bits)
.map(|(t, f)| cond.select(t, f)) .map(|(t, f)| cond.select(t, f));
.collect::<Result<Vec<_>, SynthesisError>>()?; let mut bits = [Boolean::FALSE; $size];
let selected_value = cond.value().ok().and_then(|cond| { for (result, new) in bits.iter_mut().zip(selected_bits) {
*result = new?;
}
let value = cond.value().ok().and_then(|cond| {
if cond { if cond {
true_value.value().ok() true_value.value().ok()
} else { } else {
false_value.value().ok() false_value.value().ok()
} }
}); });
Ok(Self { Ok(Self { bits, value })
bits: selected_bits,
value: selected_value,
})
} }
} }
@@ -372,19 +365,18 @@ macro_rules! make_uint {
) -> Result<Self, SynthesisError> { ) -> Result<Self, SynthesisError> {
let ns = cs.into(); let ns = cs.into();
let cs = ns.cs(); let cs = ns.cs();
let value = f().map(|f| *f.borrow()); let value = f().map(|f| *f.borrow()).ok();
let values = match value {
Ok(val) => (0..$size).map(|i| Some((val >> i) & 1 == 1)).collect(), let mut values = [None; $size];
_ => vec![None; $size], if let Some(val) = value {
}; values.iter_mut().enumerate().for_each(|(i, v)| *v = Some((val >> i) & 1 == 1));
let bits = values }
.into_iter()
.map(|v| Boolean::new_variable(cs.clone(), || v.get(), mode)) let mut bits = [Boolean::FALSE; $size];
.collect::<Result<Vec<_>, _>>()?; for (b, v) in bits.iter_mut().zip(&values) {
Ok(Self { *b = Boolean::new_variable(cs.clone(), || v.get(), mode)?;
bits, }
value: value.ok(), Ok(Self { bits, value })
})
} }
} }

View File

@@ -3,14 +3,14 @@ use ark_ff::{Field, FpParameters, PrimeField, ToConstraintField};
use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError}; use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError};
use crate::{fields::fp::AllocatedFp, prelude::*, Assignment, Vec}; use crate::{fields::fp::AllocatedFp, prelude::*, Assignment, Vec};
use core::borrow::Borrow; use core::{borrow::Borrow, convert::TryFrom};
/// Represents an interpretation of 8 `Boolean` objects as an /// Represents an interpretation of 8 `Boolean` objects as an
/// unsigned integer. /// unsigned integer.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct UInt8<F: Field> { pub struct UInt8<F: Field> {
/// Little-endian representation: least significant bit first /// Little-endian representation: least significant bit first
pub(crate) bits: Vec<Boolean<F>>, pub(crate) bits: [Boolean<F>; 8],
pub(crate) value: Option<u8>, pub(crate) value: Option<u8>,
} }
@@ -18,7 +18,7 @@ impl<F: Field> R1CSVar<F> for UInt8<F> {
type Value = u8; type Value = u8;
fn cs(&self) -> ConstraintSystemRef<F> { fn cs(&self) -> ConstraintSystemRef<F> {
self.bits.as_slice().cs() self.bits.as_ref().cs()
} }
fn value(&self) -> Result<Self::Value, SynthesisError> { fn value(&self) -> Result<Self::Value, SynthesisError> {
@@ -84,12 +84,12 @@ impl<F: Field> UInt8<F> {
/// # } /// # }
/// ``` /// ```
pub fn constant(value: u8) -> Self { pub fn constant(value: u8) -> Self {
let mut bits = Vec::with_capacity(8); let mut bits = [Boolean::FALSE; 8];
let mut tmp = value; let mut tmp = value;
for _ in 0..8 { for i in 0..8 {
// If last bit is one, push one. // If last bit is one, push one.
bits.push(Boolean::constant(tmp & 1 == 1)); bits[i] = Boolean::constant((tmp & 1) == 1);
tmp >>= 1; tmp >>= 1;
} }
@@ -201,8 +201,7 @@ impl<F: Field> UInt8<F> {
#[tracing::instrument(target = "r1cs")] #[tracing::instrument(target = "r1cs")]
pub fn from_bits_le(bits: &[Boolean<F>]) -> Self { pub fn from_bits_le(bits: &[Boolean<F>]) -> Self {
assert_eq!(bits.len(), 8); assert_eq!(bits.len(), 8);
let bits = <&[Boolean<F>; 8]>::try_from(bits).unwrap().clone();
let bits = bits.to_vec();
let mut value = Some(0u8); let mut value = Some(0u8);
for (i, b) in bits.iter().enumerate() { for (i, b) in bits.iter().enumerate() {
@@ -239,29 +238,26 @@ impl<F: Field> UInt8<F> {
/// ``` /// ```
#[tracing::instrument(target = "r1cs")] #[tracing::instrument(target = "r1cs")]
pub fn xor(&self, other: &Self) -> Result<Self, SynthesisError> { pub fn xor(&self, other: &Self) -> Result<Self, SynthesisError> {
let new_value = match (self.value, other.value) { let mut result = self.clone();
result.value = match (self.value, other.value) {
(Some(a), Some(b)) => Some(a ^ b), (Some(a), Some(b)) => Some(a ^ b),
_ => None, _ => None,
}; };
let bits = self let new_bits = self.bits.iter().zip(&other.bits).map(|(a, b)| a.xor(b));
.bits
.iter()
.zip(other.bits.iter())
.map(|(a, b)| a.xor(b))
.collect::<Result<_, _>>()?;
Ok(Self { for (res, new) in result.bits.iter_mut().zip(new_bits) {
bits, *res = new?;
value: new_value, }
})
Ok(result)
} }
} }
impl<ConstraintF: Field> EqGadget<ConstraintF> for UInt8<ConstraintF> { impl<ConstraintF: Field> EqGadget<ConstraintF> for UInt8<ConstraintF> {
#[tracing::instrument(target = "r1cs")] #[tracing::instrument(target = "r1cs")]
fn is_eq(&self, other: &Self) -> Result<Boolean<ConstraintF>, SynthesisError> { fn is_eq(&self, other: &Self) -> Result<Boolean<ConstraintF>, SynthesisError> {
self.bits.as_slice().is_eq(&other.bits) self.bits.as_ref().is_eq(&other.bits)
} }
#[tracing::instrument(target = "r1cs")] #[tracing::instrument(target = "r1cs")]
@@ -295,19 +291,20 @@ impl<ConstraintF: Field> CondSelectGadget<ConstraintF> for UInt8<ConstraintF> {
.bits .bits
.iter() .iter()
.zip(&false_value.bits) .zip(&false_value.bits)
.map(|(t, f)| cond.select(t, f)) .map(|(t, f)| cond.select(t, f));
.collect::<Result<Vec<_>, SynthesisError>>()?; let mut bits = [Boolean::FALSE; 8];
let selected_value = cond.value().ok().and_then(|cond| { for (result, new) in bits.iter_mut().zip(selected_bits) {
*result = new?;
}
let value = cond.value().ok().and_then(|cond| {
if cond { if cond {
true_value.value().ok() true_value.value().ok()
} else { } else {
false_value.value().ok() false_value.value().ok()
} }
}); });
Ok(Self { Ok(Self { bits, value })
bits: selected_bits,
value: selected_value,
})
} }
} }
@@ -319,19 +316,21 @@ impl<ConstraintF: Field> AllocVar<u8, ConstraintF> for UInt8<ConstraintF> {
) -> Result<Self, SynthesisError> { ) -> Result<Self, SynthesisError> {
let ns = cs.into(); let ns = cs.into();
let cs = ns.cs(); let cs = ns.cs();
let value = f().map(|f| *f.borrow()); let value = f().map(|f| *f.borrow()).ok();
let values = match value {
Ok(val) => (0..8).map(|i| Some((val >> i) & 1 == 1)).collect(), let mut values = [None; 8];
_ => vec![None; 8], if let Some(val) = value {
}; values
let bits = values .iter_mut()
.into_iter() .enumerate()
.map(|v| Boolean::new_variable(cs.clone(), || v.get(), mode)) .for_each(|(i, v)| *v = Some((val >> i) & 1 == 1));
.collect::<Result<Vec<_>, _>>()?; }
Ok(Self {
bits, let mut bits = [Boolean::FALSE; 8];
value: value.ok(), for (b, v) in bits.iter_mut().zip(&values) {
}) *b = Boolean::new_variable(cs.clone(), || v.get(), mode)?;
}
Ok(Self { bits, value })
} }
} }