Add back ToBytesGadget and ToBitsGadget to prelude (#136)

This commit is contained in:
Pratyush Mishra
2024-01-06 16:51:55 -05:00
committed by GitHub
parent d011859416
commit a12499518c
26 changed files with 874 additions and 231 deletions

View File

@@ -7,11 +7,16 @@ use super::*;
impl<const N: usize, T: PrimUInt, F: Field> UInt<N, T, F> {
fn _and(&self, other: &Self) -> Result<Self, SynthesisError> {
let mut result = self.clone();
for (a, b) in result.bits.iter_mut().zip(&other.bits) {
result._and_in_place(other)?;
Ok(result)
}
fn _and_in_place(&mut self, other: &Self) -> Result<(), SynthesisError> {
for (a, b) in self.bits.iter_mut().zip(&other.bits) {
*a &= b;
}
result.value = self.value.and_then(|a| Some(a & other.value?));
Ok(result)
self.value = self.value.and_then(|a| Some(a & other.value?));
Ok(())
}
}
@@ -70,8 +75,9 @@ impl<'a, const N: usize, T: PrimUInt, F: Field> BitAnd<&'a Self> for UInt<N, T,
/// # }
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand(self, other: &Self) -> Self::Output {
self._and(&other).unwrap()
fn bitand(mut self, other: &Self) -> Self::Output {
self._and_in_place(other).unwrap();
self
}
}
@@ -102,7 +108,7 @@ impl<'a, const N: usize, T: PrimUInt, F: Field> BitAnd<UInt<N, T, F>> for &'a UI
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand(self, other: UInt<N, T, F>) -> Self::Output {
self._and(&other).unwrap()
other & self
}
}
@@ -133,7 +139,43 @@ impl<const N: usize, T: PrimUInt, F: Field> BitAnd<Self> for UInt<N, T, F> {
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand(self, other: Self) -> Self::Output {
self._and(&other).unwrap()
self & &other
}
}
impl<'a, const N: usize, T: PrimUInt, F: Field> BitAnd<T> for UInt<N, T, F> {
type Output = UInt<N, T, F>;
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand(self, other: T) -> Self::Output {
self & UInt::constant(other)
}
}
impl<'a, const N: usize, T: PrimUInt, F: Field> BitAnd<&'a T> for UInt<N, T, F> {
type Output = UInt<N, T, F>;
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand(self, other: &'a T) -> Self::Output {
self & UInt::constant(*other)
}
}
impl<'a, const N: usize, T: PrimUInt, F: Field> BitAnd<&'a T> for &'a UInt<N, T, F> {
type Output = UInt<N, T, F>;
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand(self, other: &'a T) -> Self::Output {
self & UInt::constant(*other)
}
}
impl<'a, const N: usize, T: PrimUInt, F: Field> BitAnd<T> for &'a UInt<N, T, F> {
type Output = UInt<N, T, F>;
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand(self, other: T) -> Self::Output {
self & UInt::constant(other)
}
}
@@ -163,8 +205,7 @@ impl<const N: usize, T: PrimUInt, F: Field> BitAndAssign<Self> for UInt<N, T, F>
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand_assign(&mut self, other: Self) {
let result = self._and(&other).unwrap();
*self = result;
self._and_in_place(&other).unwrap();
}
}
@@ -194,8 +235,21 @@ impl<'a, const N: usize, T: PrimUInt, F: Field> BitAndAssign<&'a Self> for UInt<
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand_assign(&mut self, other: &'a Self) {
let result = self._and(other).unwrap();
*self = result;
self._and_in_place(&other).unwrap();
}
}
impl<const N: usize, T: PrimUInt, F: Field> BitAndAssign<T> for UInt<N, T, F> {
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand_assign(&mut self, other: T) {
*self &= &Self::constant(other);
}
}
impl<'a, const N: usize, T: PrimUInt, F: Field> BitAndAssign<&'a T> for UInt<N, T, F> {
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand_assign(&mut self, other: &'a T) {
*self &= &Self::constant(*other);
}
}
@@ -205,7 +259,7 @@ mod tests {
use crate::{
alloc::{AllocVar, AllocationMode},
prelude::EqGadget,
uint::test_utils::{run_binary_exhaustive, run_binary_random},
uint::test_utils::{run_binary_exhaustive_both, run_binary_random_both},
R1CSVar,
};
use ark_ff::PrimeField;
@@ -236,28 +290,65 @@ mod tests {
Ok(())
}
fn uint_and_native<T: PrimUInt, const N: usize, F: PrimeField>(
a: UInt<N, T, F>,
b: T,
) -> Result<(), SynthesisError> {
let cs = a.cs();
let computed = &a & b;
let expected_mode = if a.is_constant() {
AllocationMode::Constant
} else {
AllocationMode::Witness
};
let expected =
UInt::<N, T, F>::new_variable(cs.clone(), || Ok(a.value()? & b), expected_mode)?;
assert_eq!(expected.value(), computed.value());
expected.enforce_equal(&computed)?;
if !a.is_constant() {
assert!(cs.is_satisfied().unwrap());
}
Ok(())
}
#[test]
fn u8_and() {
run_binary_exhaustive(uint_and::<u8, 8, Fr>).unwrap()
run_binary_exhaustive_both(uint_and::<u8, 8, Fr>, uint_and_native::<u8, 8, Fr>).unwrap()
}
#[test]
fn u16_and() {
run_binary_random::<1000, 16, _, _>(uint_and::<u16, 16, Fr>).unwrap()
run_binary_random_both::<1000, 16, _, _>(
uint_and::<u16, 16, Fr>,
uint_and_native::<u16, 16, Fr>,
)
.unwrap()
}
#[test]
fn u32_and() {
run_binary_random::<1000, 32, _, _>(uint_and::<u32, 32, Fr>).unwrap()
run_binary_random_both::<1000, 32, _, _>(
uint_and::<u32, 32, Fr>,
uint_and_native::<u32, 32, Fr>,
)
.unwrap()
}
#[test]
fn u64_and() {
run_binary_random::<1000, 64, _, _>(uint_and::<u64, 64, Fr>).unwrap()
run_binary_random_both::<1000, 64, _, _>(
uint_and::<u64, 64, Fr>,
uint_and_native::<u64, 64, Fr>,
)
.unwrap()
}
#[test]
fn u128_and() {
run_binary_random::<1000, 128, _, _>(uint_and::<u128, 128, Fr>).unwrap()
run_binary_random_both::<1000, 128, _, _>(
uint_and::<u128, 128, Fr>,
uint_and_native::<u128, 128, Fr>,
)
.unwrap()
}
}

View File

@@ -72,6 +72,71 @@ impl<const N: usize, F: Field, T: PrimUInt> UInt<N, T, F> {
let value = value_exists.then_some(value);
Self { bits, value }
}
/// Converts a big-endian list of bytes into a `UInt`.
///
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let var = UInt16::new_witness(cs.clone(), || Ok(2 * (u8::MAX as u16)))?;
///
/// // Construct u8::MAX * 2
/// let bytes = UInt8::constant_vec(&(2 * (u8::MAX as u16)).to_be_bytes());
///
/// let c = UInt16::from_bytes_be(&bytes)?;
/// var.enforce_equal(&c)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
pub fn from_bytes_be(bytes: &[UInt8<F>]) -> Result<Self, SynthesisError> {
let bits = bytes
.iter()
.rev()
.flat_map(|b| b.to_bits_le().unwrap())
.collect::<Vec<_>>();
Ok(Self::from_bits_le(&bits))
}
/// Converts a little-endian byte order list of bytes into a `UInt`.
///
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let var = UInt16::new_witness(cs.clone(), || Ok(2 * (u8::MAX as u16)))?;
///
/// // Construct u8::MAX * 2
/// let bytes = UInt8::constant_vec(&(2 * (u8::MAX as u16)).to_le_bytes());
///
/// let c = UInt16::from_bytes_le(&bytes)?;
/// var.enforce_equal(&c)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
pub fn from_bytes_le(bytes: &[UInt8<F>]) -> Result<Self, SynthesisError> {
let bits = bytes
.iter()
.flat_map(|b| b.to_bits_le().unwrap())
.collect::<Vec<_>>();
Ok(Self::from_bits_le(&bits))
}
pub fn to_bytes_be(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
let mut bytes = self.to_bytes_le()?;
bytes.reverse();
Ok(bytes)
}
}
impl<const N: usize, T: PrimUInt, F: Field> ToBitsGadget<F> for UInt<N, T, F> {
@@ -97,7 +162,7 @@ impl<const N: usize, T: PrimUInt, ConstraintF: Field> ToBytesGadget<ConstraintF>
for UInt<N, T, ConstraintF>
{
#[tracing::instrument(target = "r1cs", skip(self))]
fn to_bytes(&self) -> Result<Vec<UInt8<ConstraintF>>, SynthesisError> {
fn to_bytes_le(&self) -> Result<Vec<UInt8<ConstraintF>>, SynthesisError> {
Ok(self
.to_bits_le()?
.chunks(8)
@@ -107,23 +172,209 @@ impl<const N: usize, T: PrimUInt, ConstraintF: Field> ToBytesGadget<ConstraintF>
}
impl<const N: usize, T: PrimUInt, F: Field> ToBytesGadget<F> for [UInt<N, T, F>] {
fn to_bytes(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
fn to_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
let mut bytes = Vec::with_capacity(self.len() * (N / 8));
for elem in self {
bytes.extend_from_slice(&elem.to_bytes()?);
bytes.extend_from_slice(&elem.to_bytes_le()?);
}
Ok(bytes)
}
}
impl<const N: usize, T: PrimUInt, F: Field> ToBytesGadget<F> for Vec<UInt<N, T, F>> {
fn to_bytes(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
self.as_slice().to_bytes()
fn to_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
self.as_slice().to_bytes_le()
}
}
impl<'a, const N: usize, T: PrimUInt, F: Field> ToBytesGadget<F> for &'a [UInt<N, T, F>] {
fn to_bytes(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
(*self).to_bytes()
fn to_bytes_le(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
(*self).to_bytes_le()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
prelude::EqGadget,
uint::test_utils::{run_unary_exhaustive, run_unary_random},
R1CSVar,
};
use ark_ff::PrimeField;
use ark_test_curves::bls12_381::Fr;
fn uint_to_bytes_le<T: PrimUInt, const N: usize, F: PrimeField>(
a: UInt<N, T, F>,
) -> Result<(), SynthesisError> {
let cs = a.cs();
let computed = a.to_bytes_le()?;
let expected = UInt8::constant_vec(a.value()?.to_le_bytes().as_ref());
assert_eq!(expected.len(), computed.len());
assert_eq!(expected.value(), computed.value());
expected.enforce_equal(&computed)?;
if !a.is_constant() {
assert!(cs.is_satisfied().unwrap());
}
Ok(())
}
fn uint_to_bytes_be<T: PrimUInt, const N: usize, F: PrimeField>(
a: UInt<N, T, F>,
) -> Result<(), SynthesisError> {
let cs = a.cs();
let computed = a.to_bytes_be()?;
let expected = UInt8::constant_vec(a.value()?.to_be_bytes().as_ref());
assert_eq!(expected.len(), computed.len());
assert_eq!(expected.value(), computed.value());
expected.enforce_equal(&computed)?;
if !a.is_constant() {
assert!(cs.is_satisfied().unwrap());
}
Ok(())
}
fn uint_from_bytes_le<T: PrimUInt, const N: usize, F: PrimeField>(
expected: UInt<N, T, F>,
) -> Result<(), SynthesisError> {
let cs = expected.cs();
let mode = if expected.is_constant() {
AllocationMode::Constant
} else {
AllocationMode::Witness
};
let computed = {
let value = expected.value()?.to_le_bytes();
let a = Vec::<UInt8<F>>::new_variable(cs.clone(), || Ok(value.as_ref()), mode)?;
UInt::from_bytes_le(&a)?
};
assert_eq!(expected.value(), computed.value());
expected.enforce_equal(&computed)?;
if !expected.is_constant() {
assert!(cs.is_satisfied().unwrap());
}
Ok(())
}
fn uint_from_bytes_be<T: PrimUInt, const N: usize, F: PrimeField>(
expected: UInt<N, T, F>,
) -> Result<(), SynthesisError> {
let cs = expected.cs();
let mode = if expected.is_constant() {
AllocationMode::Constant
} else {
AllocationMode::Witness
};
let computed = {
let value = expected.value()?.to_be_bytes();
let a = Vec::<UInt8<F>>::new_variable(cs.clone(), || Ok(value.as_ref()), mode)?;
UInt::from_bytes_be(&a)?
};
assert_eq!(expected.value(), computed.value());
expected.enforce_equal(&computed)?;
if !expected.is_constant() {
assert!(cs.is_satisfied().unwrap());
}
Ok(())
}
#[test]
fn u8_to_bytes_le() {
run_unary_exhaustive(uint_to_bytes_le::<u8, 8, Fr>).unwrap()
}
#[test]
fn u16_to_bytes_le() {
run_unary_random::<1000, 16, _, _>(uint_to_bytes_le::<u16, 16, Fr>).unwrap()
}
#[test]
fn u32_to_bytes_le() {
run_unary_random::<1000, 32, _, _>(uint_to_bytes_le::<u32, 32, Fr>).unwrap()
}
#[test]
fn u64_to_bytes_le() {
run_unary_random::<1000, 64, _, _>(uint_to_bytes_le::<u64, 64, Fr>).unwrap()
}
#[test]
fn u128_to_bytes_le() {
run_unary_random::<1000, 128, _, _>(uint_to_bytes_le::<u128, 128, Fr>).unwrap()
}
#[test]
fn u8_to_bytes_be() {
run_unary_exhaustive(uint_to_bytes_be::<u8, 8, Fr>).unwrap()
}
#[test]
fn u16_to_bytes_be() {
run_unary_random::<1000, 16, _, _>(uint_to_bytes_be::<u16, 16, Fr>).unwrap()
}
#[test]
fn u32_to_bytes_be() {
run_unary_random::<1000, 32, _, _>(uint_to_bytes_be::<u32, 32, Fr>).unwrap()
}
#[test]
fn u64_to_bytes_be() {
run_unary_random::<1000, 64, _, _>(uint_to_bytes_be::<u64, 64, Fr>).unwrap()
}
#[test]
fn u128_to_bytes_be() {
run_unary_random::<1000, 128, _, _>(uint_to_bytes_be::<u128, 128, Fr>).unwrap()
}
#[test]
fn u8_from_bytes_le() {
run_unary_exhaustive(uint_from_bytes_le::<u8, 8, Fr>).unwrap()
}
#[test]
fn u16_from_bytes_le() {
run_unary_random::<1000, 16, _, _>(uint_from_bytes_le::<u16, 16, Fr>).unwrap()
}
#[test]
fn u32_from_bytes_le() {
run_unary_random::<1000, 32, _, _>(uint_from_bytes_le::<u32, 32, Fr>).unwrap()
}
#[test]
fn u64_from_bytes_le() {
run_unary_random::<1000, 64, _, _>(uint_from_bytes_le::<u64, 64, Fr>).unwrap()
}
#[test]
fn u128_from_bytes_le() {
run_unary_random::<1000, 128, _, _>(uint_from_bytes_le::<u128, 128, Fr>).unwrap()
}
#[test]
fn u8_from_bytes_be() {
run_unary_exhaustive(uint_from_bytes_be::<u8, 8, Fr>).unwrap()
}
#[test]
fn u16_from_bytes_be() {
run_unary_random::<1000, 16, _, _>(uint_from_bytes_be::<u16, 16, Fr>).unwrap()
}
#[test]
fn u32_from_bytes_be() {
run_unary_random::<1000, 32, _, _>(uint_from_bytes_be::<u32, 32, Fr>).unwrap()
}
#[test]
fn u64_from_bytes_be() {
run_unary_random::<1000, 64, _, _>(uint_from_bytes_be::<u64, 64, Fr>).unwrap()
}
#[test]
fn u128_from_bytes_be() {
run_unary_random::<1000, 128, _, _>(uint_from_bytes_be::<u128, 128, Fr>).unwrap()
}
}

View File

@@ -7,12 +7,17 @@ use super::*;
impl<const N: usize, T: PrimUInt, F: Field> UInt<N, T, F> {
fn _not(&self) -> Result<Self, SynthesisError> {
let mut result = self.clone();
for a in &mut result.bits {
*a = !&*a
}
result.value = self.value.map(Not::not);
result._not_in_place()?;
Ok(result)
}
fn _not_in_place(&mut self) -> Result<(), SynthesisError> {
for a in &mut self.bits {
a.not_in_place()?;
}
self.value = self.value.map(Not::not);
Ok(())
}
}
impl<'a, const N: usize, T: PrimUInt, F: Field> Not for &'a UInt<N, T, F> {
@@ -67,8 +72,9 @@ impl<'a, const N: usize, T: PrimUInt, F: Field> Not for UInt<N, T, F> {
/// # }
/// ```
#[tracing::instrument(target = "r1cs", skip(self))]
fn not(self) -> Self::Output {
self._not().unwrap()
fn not(mut self) -> Self::Output {
self._not_in_place().unwrap();
self
}
}

View File

@@ -7,11 +7,16 @@ use super::{PrimUInt, UInt};
impl<const N: usize, T: PrimUInt, F: PrimeField> UInt<N, T, F> {
fn _or(&self, other: &Self) -> Result<Self, SynthesisError> {
let mut result = self.clone();
for (a, b) in result.bits.iter_mut().zip(&other.bits) {
result._or_in_place(other)?;
Ok(result)
}
fn _or_in_place(&mut self, other: &Self) -> Result<(), SynthesisError> {
for (a, b) in self.bits.iter_mut().zip(&other.bits) {
*a |= b;
}
result.value = self.value.and_then(|a| Some(a | other.value?));
Ok(result)
self.value = self.value.and_then(|a| Some(a | other.value?));
Ok(())
}
}
@@ -50,8 +55,9 @@ impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr<&'a Self> for UInt<N,
type Output = UInt<N, T, F>;
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitor(self, other: &Self) -> Self::Output {
self._or(&other).unwrap()
fn bitor(mut self, other: &Self) -> Self::Output {
self._or_in_place(&other).unwrap();
self
}
}
@@ -60,7 +66,7 @@ impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr<UInt<N, T, F>> for &'
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitor(self, other: UInt<N, T, F>) -> Self::Output {
self._or(&other).unwrap()
other | self
}
}
@@ -69,7 +75,43 @@ impl<const N: usize, T: PrimUInt, F: PrimeField> BitOr<Self> for UInt<N, T, F> {
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitor(self, other: Self) -> Self::Output {
self._or(&other).unwrap()
self | &other
}
}
impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr<T> for UInt<N, T, F> {
type Output = UInt<N, T, F>;
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitor(self, other: T) -> Self::Output {
self | &UInt::constant(other)
}
}
impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr<&'a T> for UInt<N, T, F> {
type Output = UInt<N, T, F>;
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitor(self, other: &'a T) -> Self::Output {
self | &UInt::constant(*other)
}
}
impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr<&'a T> for &'a UInt<N, T, F> {
type Output = UInt<N, T, F>;
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitor(self, other: &'a T) -> Self::Output {
self | &UInt::constant(*other)
}
}
impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr<T> for &'a UInt<N, T, F> {
type Output = UInt<N, T, F>;
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitor(self, other: T) -> Self::Output {
self | &UInt::constant(other)
}
}
@@ -99,16 +141,28 @@ impl<const N: usize, T: PrimUInt, F: PrimeField> BitOrAssign<Self> for UInt<N, T
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitor_assign(&mut self, other: Self) {
let result = self._or(&other).unwrap();
*self = result;
self._or_in_place(&other).unwrap();
}
}
impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOrAssign<&'a Self> for UInt<N, T, F> {
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitor_assign(&mut self, other: &'a Self) {
let result = self._or(other).unwrap();
*self = result;
self._or_in_place(other).unwrap();
}
}
impl<const N: usize, T: PrimUInt, F: PrimeField> BitOrAssign<T> for UInt<N, T, F> {
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitor_assign(&mut self, other: T) {
*self |= &UInt::constant(other);
}
}
impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOrAssign<&'a T> for UInt<N, T, F> {
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitor_assign(&mut self, other: &'a T) {
*self |= &UInt::constant(*other);
}
}
@@ -118,7 +172,7 @@ mod tests {
use crate::{
alloc::{AllocVar, AllocationMode},
prelude::EqGadget,
uint::test_utils::{run_binary_exhaustive, run_binary_random},
uint::test_utils::{run_binary_exhaustive_both, run_binary_random_both},
R1CSVar,
};
use ark_ff::PrimeField;
@@ -149,28 +203,65 @@ mod tests {
Ok(())
}
fn uint_or_native<T: PrimUInt, const N: usize, F: PrimeField>(
a: UInt<N, T, F>,
b: T,
) -> Result<(), SynthesisError> {
let cs = a.cs();
let computed = &a | &b;
let expected_mode = if a.is_constant() {
AllocationMode::Constant
} else {
AllocationMode::Witness
};
let expected =
UInt::<N, T, F>::new_variable(cs.clone(), || Ok(a.value()? | b), expected_mode)?;
assert_eq!(expected.value(), computed.value());
expected.enforce_equal(&computed)?;
if !a.is_constant() {
assert!(cs.is_satisfied().unwrap());
}
Ok(())
}
#[test]
fn u8_or() {
run_binary_exhaustive(uint_or::<u8, 8, Fr>).unwrap()
run_binary_exhaustive_both(uint_or::<u8, 8, Fr>, uint_or_native::<u8, 8, Fr>).unwrap()
}
#[test]
fn u16_or() {
run_binary_random::<1000, 16, _, _>(uint_or::<u16, 16, Fr>).unwrap()
run_binary_random_both::<1000, 16, _, _>(
uint_or::<u16, 16, Fr>,
uint_or_native::<u16, 16, Fr>,
)
.unwrap()
}
#[test]
fn u32_or() {
run_binary_random::<1000, 32, _, _>(uint_or::<u32, 32, Fr>).unwrap()
run_binary_random_both::<1000, 32, _, _>(
uint_or::<u32, 32, Fr>,
uint_or_native::<u32, 32, Fr>,
)
.unwrap()
}
#[test]
fn u64_or() {
run_binary_random::<1000, 64, _, _>(uint_or::<u64, 64, Fr>).unwrap()
run_binary_random_both::<1000, 64, _, _>(
uint_or::<u64, 64, Fr>,
uint_or_native::<u64, 64, Fr>,
)
.unwrap()
}
#[test]
fn u128_or() {
run_binary_random::<1000, 128, _, _>(uint_or::<u128, 128, Fr>).unwrap()
run_binary_random_both::<1000, 128, _, _>(
uint_or::<u128, 128, Fr>,
uint_or_native::<u128, 128, Fr>,
)
.unwrap()
}
}

View File

@@ -22,13 +22,37 @@ impl<const N: usize, T: PrimUInt, ConstraintF: Field> UInt<N, T, ConstraintF> {
/// ```
#[tracing::instrument(target = "r1cs", skip(self))]
pub fn rotate_right(&self, by: usize) -> Self {
let by = by % N;
let mut result = self.clone();
// `[T]::rotate_left` corresponds to a `rotate_right` of the bits.
result.bits.rotate_left(by);
result.value = self.value.map(|v| v.rotate_right(by as u32));
result.rotate_right_in_place(by);
result
}
/// Rotates `self` to the right *in place* by `by` steps, wrapping around.
///
/// # Examples
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let mut a = UInt32::new_witness(cs.clone(), || Ok(0xb301u32))?;
/// let b = UInt32::new_witness(cs.clone(), || Ok(0x10000b3))?;
///
/// a.rotate_right_in_place(8);
/// a.enforce_equal(&b)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
#[tracing::instrument(target = "r1cs", skip(self))]
pub fn rotate_right_in_place(&mut self, by: usize) {
let by = by % N;
// `[T]::rotate_left` corresponds to a `rotate_right` of the bits.
self.bits.rotate_left(by);
self.value = self.value.map(|v| v.rotate_right(by as u32));
}
/// Rotates `self` to the left by `by` steps, wrapping around.
///
@@ -51,13 +75,37 @@ impl<const N: usize, T: PrimUInt, ConstraintF: Field> UInt<N, T, ConstraintF> {
/// ```
#[tracing::instrument(target = "r1cs", skip(self))]
pub fn rotate_left(&self, by: usize) -> Self {
let by = by % N;
let mut result = self.clone();
// `[T]::rotate_right` corresponds to a `rotate_left` of the bits.
result.bits.rotate_right(by);
result.value = self.value.map(|v| v.rotate_left(by as u32));
result.rotate_left_in_place(by);
result
}
/// Rotates `self` to the left *in place* by `by` steps, wrapping around.
///
/// # Examples
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let mut a = UInt32::new_witness(cs.clone(), || Ok(0x10000b3))?;
/// let b = UInt32::new_witness(cs.clone(), || Ok(0xb301u32))?;
///
/// a.rotate_left_in_place(8);
/// a.enforce_equal(&b)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
pub fn rotate_left_in_place(&mut self, by: usize) {
let by = by % N;
// `[T]::rotate_right` corresponds to a `rotate_left` of the bits.
self.bits.rotate_right(by);
self.value = self.value.map(|v| v.rotate_left(by as u32));
}
}
#[cfg(test)]

View File

@@ -99,7 +99,7 @@ mod tests {
use crate::{
alloc::{AllocVar, AllocationMode},
prelude::EqGadget,
uint::test_utils::{run_binary_exhaustive_with_native, run_binary_random_with_native},
uint::test_utils::{run_binary_exhaustive_native_only, run_binary_random_native_only},
R1CSVar,
};
use ark_ff::PrimeField;
@@ -129,26 +129,26 @@ mod tests {
#[test]
fn u8_shl() {
run_binary_exhaustive_with_native(uint_shl::<u8, 8, Fr>).unwrap()
run_binary_exhaustive_native_only(uint_shl::<u8, 8, Fr>).unwrap()
}
#[test]
fn u16_shl() {
run_binary_random_with_native::<1000, 16, _, _>(uint_shl::<u16, 16, Fr>).unwrap()
run_binary_random_native_only::<1000, 16, _, _>(uint_shl::<u16, 16, Fr>).unwrap()
}
#[test]
fn u32_shl() {
run_binary_random_with_native::<1000, 32, _, _>(uint_shl::<u32, 32, Fr>).unwrap()
run_binary_random_native_only::<1000, 32, _, _>(uint_shl::<u32, 32, Fr>).unwrap()
}
#[test]
fn u64_shl() {
run_binary_random_with_native::<1000, 64, _, _>(uint_shl::<u64, 64, Fr>).unwrap()
run_binary_random_native_only::<1000, 64, _, _>(uint_shl::<u64, 64, Fr>).unwrap()
}
#[test]
fn u128_shl() {
run_binary_random_with_native::<1000, 128, _, _>(uint_shl::<u128, 128, Fr>).unwrap()
run_binary_random_native_only::<1000, 128, _, _>(uint_shl::<u128, 128, Fr>).unwrap()
}
}

View File

@@ -99,7 +99,7 @@ mod tests {
use crate::{
alloc::{AllocVar, AllocationMode},
prelude::EqGadget,
uint::test_utils::{run_binary_exhaustive_with_native, run_binary_random_with_native},
uint::test_utils::{run_binary_exhaustive_native_only, run_binary_random_native_only},
R1CSVar,
};
use ark_ff::PrimeField;
@@ -129,26 +129,26 @@ mod tests {
#[test]
fn u8_shr() {
run_binary_exhaustive_with_native(uint_shr::<u8, 8, Fr>).unwrap()
run_binary_exhaustive_native_only(uint_shr::<u8, 8, Fr>).unwrap()
}
#[test]
fn u16_shr() {
run_binary_random_with_native::<1000, 16, _, _>(uint_shr::<u16, 16, Fr>).unwrap()
run_binary_random_native_only::<1000, 16, _, _>(uint_shr::<u16, 16, Fr>).unwrap()
}
#[test]
fn u32_shr() {
run_binary_random_with_native::<1000, 32, _, _>(uint_shr::<u32, 32, Fr>).unwrap()
run_binary_random_native_only::<1000, 32, _, _>(uint_shr::<u32, 32, Fr>).unwrap()
}
#[test]
fn u64_shr() {
run_binary_random_with_native::<1000, 64, _, _>(uint_shr::<u64, 64, Fr>).unwrap()
run_binary_random_native_only::<1000, 64, _, _>(uint_shr::<u64, 64, Fr>).unwrap()
}
#[test]
fn u128_shr() {
run_binary_random_with_native::<1000, 128, _, _>(uint_shr::<u128, 128, Fr>).unwrap()
run_binary_random_native_only::<1000, 128, _, _>(uint_shr::<u128, 128, Fr>).unwrap()
}
}

View File

@@ -39,6 +39,29 @@ pub(crate) fn test_binary_op_with_native<T: PrimUInt, const N: usize, F: PrimeFi
test(a, b)
}
pub(crate) fn run_binary_random_both<const ITERATIONS: usize, const N: usize, T, F>(
test: impl Fn(UInt<N, T, F>, UInt<N, T, F>) -> Result<(), SynthesisError> + Copy,
test_native: impl Fn(UInt<N, T, F>, T) -> Result<(), SynthesisError> + Copy,
) -> Result<(), SynthesisError>
where
T: PrimUInt,
F: PrimeField,
{
let mut rng = ark_std::test_rng();
for _ in 0..ITERATIONS {
for mode_a in modes() {
let a = T::rand(&mut rng);
for mode_b in modes() {
let b = T::rand(&mut rng);
test_binary_op(a, b, mode_a, mode_b, test)?;
test_binary_op_with_native(a, b, mode_a, test_native)?;
}
}
}
Ok(())
}
pub(crate) fn run_binary_random<const ITERATIONS: usize, const N: usize, T, F>(
test: impl Fn(UInt<N, T, F>, UInt<N, T, F>) -> Result<(), SynthesisError> + Copy,
) -> Result<(), SynthesisError>
@@ -76,7 +99,25 @@ where
Ok(())
}
pub(crate) fn run_binary_random_with_native<const ITERATIONS: usize, const N: usize, T, F>(
pub(crate) fn run_binary_exhaustive_both<const N: usize, T, F>(
test: impl Fn(UInt<N, T, F>, UInt<N, T, F>) -> Result<(), SynthesisError> + Copy,
test_native: impl Fn(UInt<N, T, F>, T) -> Result<(), SynthesisError> + Copy,
) -> Result<(), SynthesisError>
where
T: PrimUInt,
F: PrimeField,
RangeInclusive<T>: Iterator<Item = T>,
{
for (mode_a, a) in test_utils::combination(T::min_value()..=T::max_value()) {
for (mode_b, b) in test_utils::combination(T::min_value()..=T::max_value()) {
test_binary_op(a, b, mode_a, mode_b, test)?;
test_binary_op_with_native(a, b, mode_a, test_native)?;
}
}
Ok(())
}
pub(crate) fn run_binary_random_native_only<const ITERATIONS: usize, const N: usize, T, F>(
test: impl Fn(UInt<N, T, F>, T) -> Result<(), SynthesisError> + Copy,
) -> Result<(), SynthesisError>
where
@@ -95,7 +136,7 @@ where
Ok(())
}
pub(crate) fn run_binary_exhaustive_with_native<const N: usize, T, F>(
pub(crate) fn run_binary_exhaustive_native_only<const N: usize, T, F>(
test: impl Fn(UInt<N, T, F>, T) -> Result<(), SynthesisError> + Copy,
) -> Result<(), SynthesisError>
where

View File

@@ -7,11 +7,16 @@ use super::*;
impl<const N: usize, T: PrimUInt, F: Field> UInt<N, T, F> {
fn _xor(&self, other: &Self) -> Result<Self, SynthesisError> {
let mut result = self.clone();
for (a, b) in result.bits.iter_mut().zip(&other.bits) {
result._xor_in_place(other)?;
Ok(result)
}
fn _xor_in_place(&mut self, other: &Self) -> Result<(), SynthesisError> {
for (a, b) in self.bits.iter_mut().zip(&other.bits) {
*a ^= b;
}
result.value = self.value.and_then(|a| Some(a ^ other.value?));
Ok(result)
self.value = self.value.and_then(|a| Some(a ^ other.value?));
Ok(())
}
}
@@ -49,8 +54,9 @@ impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor<&'a Self> for UInt<N, T,
type Output = UInt<N, T, F>;
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitxor(self, other: &Self) -> Self::Output {
self._xor(&other).unwrap()
fn bitxor(mut self, other: &Self) -> Self::Output {
self._xor_in_place(&other).unwrap();
self
}
}
@@ -59,7 +65,7 @@ impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor<UInt<N, T, F>> for &'a UI
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitxor(self, other: UInt<N, T, F>) -> Self::Output {
self._xor(&other).unwrap()
other ^ self
}
}
@@ -68,7 +74,43 @@ impl<const N: usize, T: PrimUInt, F: Field> BitXor<Self> for UInt<N, T, F> {
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitxor(self, other: Self) -> Self::Output {
self._xor(&other).unwrap()
self ^ &other
}
}
impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor<T> for UInt<N, T, F> {
type Output = UInt<N, T, F>;
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitxor(self, other: T) -> Self::Output {
self ^ &UInt::constant(other)
}
}
impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor<&'a T> for UInt<N, T, F> {
type Output = UInt<N, T, F>;
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitxor(self, other: &'a T) -> Self::Output {
self ^ &UInt::constant(*other)
}
}
impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor<&'a T> for &'a UInt<N, T, F> {
type Output = UInt<N, T, F>;
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitxor(self, other: &'a T) -> Self::Output {
self ^ UInt::constant(*other)
}
}
impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor<T> for &'a UInt<N, T, F> {
type Output = UInt<N, T, F>;
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitxor(self, other: T) -> Self::Output {
self ^ UInt::constant(other)
}
}
@@ -98,16 +140,28 @@ impl<const N: usize, T: PrimUInt, F: Field> BitXorAssign<Self> for UInt<N, T, F>
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitxor_assign(&mut self, other: Self) {
let result = self._xor(&other).unwrap();
*self = result;
self._xor_in_place(&other).unwrap();
}
}
impl<'a, const N: usize, T: PrimUInt, F: Field> BitXorAssign<&'a Self> for UInt<N, T, F> {
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitxor_assign(&mut self, other: &'a Self) {
let result = self._xor(other).unwrap();
*self = result;
self._xor_in_place(other).unwrap();
}
}
impl<const N: usize, T: PrimUInt, F: Field> BitXorAssign<T> for UInt<N, T, F> {
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitxor_assign(&mut self, other: T) {
*self ^= Self::constant(other);
}
}
impl<'a, const N: usize, T: PrimUInt, F: Field> BitXorAssign<&'a T> for UInt<N, T, F> {
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitxor_assign(&mut self, other: &'a T) {
*self ^= Self::constant(*other);
}
}
@@ -117,7 +171,7 @@ mod tests {
use crate::{
alloc::{AllocVar, AllocationMode},
prelude::EqGadget,
uint::test_utils::{run_binary_exhaustive, run_binary_random},
uint::test_utils::{run_binary_exhaustive_both, run_binary_random_both},
R1CSVar,
};
use ark_ff::PrimeField;
@@ -148,28 +202,65 @@ mod tests {
Ok(())
}
fn uint_xor_native<T: PrimUInt, const N: usize, F: PrimeField>(
a: UInt<N, T, F>,
b: T,
) -> Result<(), SynthesisError> {
let cs = a.cs();
let computed = &a ^ &b;
let expected_mode = if a.is_constant() {
AllocationMode::Constant
} else {
AllocationMode::Witness
};
let expected =
UInt::<N, T, F>::new_variable(cs.clone(), || Ok(a.value()? ^ b), expected_mode)?;
assert_eq!(expected.value(), computed.value());
expected.enforce_equal(&computed)?;
if !a.is_constant() {
assert!(cs.is_satisfied().unwrap());
}
Ok(())
}
#[test]
fn u8_xor() {
run_binary_exhaustive(uint_xor::<u8, 8, Fr>).unwrap()
run_binary_exhaustive_both(uint_xor::<u8, 8, Fr>, uint_xor_native::<u8, 8, Fr>).unwrap()
}
#[test]
fn u16_xor() {
run_binary_random::<1000, 16, _, _>(uint_xor::<u16, 16, Fr>).unwrap()
run_binary_random_both::<1000, 16, _, _>(
uint_xor::<u16, 16, Fr>,
uint_xor_native::<u16, 16, Fr>,
)
.unwrap()
}
#[test]
fn u32_xor() {
run_binary_random::<1000, 32, _, _>(uint_xor::<u32, 32, Fr>).unwrap()
run_binary_random_both::<1000, 32, _, _>(
uint_xor::<u32, 32, Fr>,
uint_xor_native::<u32, 32, Fr>,
)
.unwrap()
}
#[test]
fn u64_xor() {
run_binary_random::<1000, 64, _, _>(uint_xor::<u64, 64, Fr>).unwrap()
run_binary_random_both::<1000, 64, _, _>(
uint_xor::<u64, 64, Fr>,
uint_xor_native::<u64, 64, Fr>,
)
.unwrap()
}
#[test]
fn u128_xor() {
run_binary_random::<1000, 128, _, _>(uint_xor::<u128, 128, Fr>).unwrap()
run_binary_random_both::<1000, 128, _, _>(
uint_xor::<u128, 128, Fr>,
uint_xor_native::<u128, 128, Fr>,
)
.unwrap()
}
}