Refactor UInt{8,16,64,128} into one struct UInt (#121)

This commit is contained in:
Pratyush Mishra
2024-01-03 08:23:54 -05:00
committed by GitHub
parent 3cb9fdef00
commit d011859416
54 changed files with 5044 additions and 3081 deletions

50
src/uint/add/mod.rs Normal file
View File

@@ -0,0 +1,50 @@
use crate::fields::fp::FpVar;
use super::*;
mod saturating;
mod wrapping;
impl<const N: usize, T: PrimUInt, F: PrimeField> UInt<N, T, F> {
/// Adds up `operands`, returning the bit decomposition of the result, along with
/// the value of the result. If all the operands are constant, then the bit decomposition
/// is empty, and the value is the constant value of the result.
///
/// # Panics
///
/// This method panics if the result of addition could possibly exceed the field size.
#[tracing::instrument(target = "r1cs", skip(operands, adder))]
fn add_many_helper(
operands: &[Self],
adder: impl Fn(T, T) -> T,
) -> Result<(Vec<Boolean<F>>, Option<T>), SynthesisError> {
// Bounds on `N` to avoid overflows
assert!(operands.len() >= 1);
let max_value_size = N as u32 + ark_std::log2(operands.len());
assert!(max_value_size <= F::MODULUS_BIT_SIZE);
if operands.len() == 1 {
return Ok((operands[0].bits.to_vec(), operands[0].value));
}
// Compute the value of the result.
let mut value = Some(T::zero());
for op in operands {
value = value.and_then(|v| Some(adder(v, op.value?)));
}
if operands.is_constant() {
// If all operands are constant, then the result is also constant.
// In this case, we can return early.
return Ok((Vec::new(), value));
}
// Compute the full (non-wrapped) sum of the operands.
let result = operands
.iter()
.map(|op| Boolean::le_bits_to_fp(&op.bits).unwrap())
.sum::<FpVar<_>>();
let (result, _) = result.to_bits_le_with_top_bits_zero(max_value_size as usize)?;
Ok((result, value))
}
}

117
src/uint/add/saturating.rs Normal file
View File

@@ -0,0 +1,117 @@
use ark_ff::PrimeField;
use ark_relations::r1cs::SynthesisError;
use crate::uint::*;
use crate::{boolean::Boolean, R1CSVar};
impl<const N: usize, T: PrimUInt, F: PrimeField> UInt<N, T, F> {
/// Compute `*self = self.wrapping_add(other)`.
pub fn saturating_add_in_place(&mut self, other: &Self) {
let result = Self::saturating_add_many(&[self.clone(), other.clone()]).unwrap();
*self = result;
}
/// Compute `self.wrapping_add(other)`.
pub fn saturating_add(&self, other: &Self) -> Self {
let mut result = self.clone();
result.saturating_add_in_place(other);
result
}
/// Perform wrapping addition of `operands`.
/// Computes `operands[0].wrapping_add(operands[1]).wrapping_add(operands[2])...`.
///
/// The user must ensure that overflow does not occur.
#[tracing::instrument(target = "r1cs", skip(operands))]
pub fn saturating_add_many(operands: &[Self]) -> Result<Self, SynthesisError>
where
F: PrimeField,
{
let (sum_bits, value) = Self::add_many_helper(operands, |a, b| a.saturating_add(b))?;
if operands.is_constant() {
// If all operands are constant, then the result is also constant.
// In this case, we can return early.
Ok(UInt::constant(value.unwrap()))
} else if sum_bits.len() == N {
// No overflow occurred.
Ok(UInt::from_bits_le(&sum_bits))
} else {
// Split the sum into the bottom `N` bits and the top bits.
let (bottom_bits, top_bits) = sum_bits.split_at(N);
// Construct a candidate result assuming that no overflow occurred.
let bits = TryFrom::try_from(bottom_bits.to_vec()).unwrap();
let candidate_result = UInt { bits, value };
// Check if any of the top bits is set.
// If any of them is set, then overflow occurred.
let overflow_occurred = Boolean::kary_or(&top_bits)?;
// If overflow occurred, return the maximum value.
overflow_occurred.select(&Self::MAX, &candidate_result)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
alloc::{AllocVar, AllocationMode},
prelude::EqGadget,
uint::test_utils::{run_binary_exhaustive, run_binary_random},
R1CSVar,
};
use ark_ff::PrimeField;
use ark_test_curves::bls12_381::Fr;
fn uint_saturating_add<T: PrimUInt, const N: usize, F: PrimeField>(
a: UInt<N, T, F>,
b: UInt<N, T, F>,
) -> Result<(), SynthesisError> {
let cs = a.cs().or(b.cs());
let both_constant = a.is_constant() && b.is_constant();
let computed = a.saturating_add(&b);
let expected_mode = if both_constant {
AllocationMode::Constant
} else {
AllocationMode::Witness
};
let expected = UInt::new_variable(
cs.clone(),
|| Ok(a.value()?.saturating_add(b.value()?)),
expected_mode,
)?;
assert_eq!(expected.value(), computed.value());
expected.enforce_equal(&computed)?;
if !both_constant {
assert!(cs.is_satisfied().unwrap());
}
Ok(())
}
#[test]
fn u8_saturating_add() {
run_binary_exhaustive(uint_saturating_add::<u8, 8, Fr>).unwrap()
}
#[test]
fn u16_saturating_add() {
run_binary_random::<1000, 16, _, _>(uint_saturating_add::<u16, 16, Fr>).unwrap()
}
#[test]
fn u32_saturating_add() {
run_binary_random::<1000, 32, _, _>(uint_saturating_add::<u32, 32, Fr>).unwrap()
}
#[test]
fn u64_saturating_add() {
run_binary_random::<1000, 64, _, _>(uint_saturating_add::<u64, 64, Fr>).unwrap()
}
#[test]
fn u128_saturating_add() {
run_binary_random::<1000, 128, _, _>(uint_saturating_add::<u128, 128, Fr>).unwrap()
}
}

106
src/uint/add/wrapping.rs Normal file
View File

@@ -0,0 +1,106 @@
use ark_ff::PrimeField;
use ark_relations::r1cs::SynthesisError;
use crate::uint::*;
use crate::R1CSVar;
impl<const N: usize, T: PrimUInt, F: PrimeField> UInt<N, T, F> {
/// Compute `*self = self.wrapping_add(other)`.
pub fn wrapping_add_in_place(&mut self, other: &Self) {
let result = Self::wrapping_add_many(&[self.clone(), other.clone()]).unwrap();
*self = result;
}
/// Compute `self.wrapping_add(other)`.
pub fn wrapping_add(&self, other: &Self) -> Self {
let mut result = self.clone();
result.wrapping_add_in_place(other);
result
}
/// Perform wrapping addition of `operands`.
/// Computes `operands[0].wrapping_add(operands[1]).wrapping_add(operands[2])...`.
///
/// The user must ensure that overflow does not occur.
#[tracing::instrument(target = "r1cs", skip(operands))]
pub fn wrapping_add_many(operands: &[Self]) -> Result<Self, SynthesisError>
where
F: PrimeField,
{
let (mut sum_bits, value) = Self::add_many_helper(operands, |a, b| a.wrapping_add(&b))?;
if operands.is_constant() {
// If all operands are constant, then the result is also constant.
// In this case, we can return early.
Ok(UInt::constant(value.unwrap()))
} else {
sum_bits.truncate(N);
Ok(UInt {
bits: sum_bits.try_into().unwrap(),
value,
})
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
alloc::{AllocVar, AllocationMode},
prelude::EqGadget,
uint::test_utils::{run_binary_exhaustive, run_binary_random},
R1CSVar,
};
use ark_ff::PrimeField;
use ark_test_curves::bls12_381::Fr;
fn uint_wrapping_add<T: PrimUInt, const N: usize, F: PrimeField>(
a: UInt<N, T, F>,
b: UInt<N, T, F>,
) -> Result<(), SynthesisError> {
let cs = a.cs().or(b.cs());
let both_constant = a.is_constant() && b.is_constant();
let computed = a.wrapping_add(&b);
let expected_mode = if both_constant {
AllocationMode::Constant
} else {
AllocationMode::Witness
};
let expected = UInt::new_variable(
cs.clone(),
|| Ok(a.value()?.wrapping_add(&b.value()?)),
expected_mode,
)?;
assert_eq!(expected.value(), computed.value());
expected.enforce_equal(&computed)?;
if !both_constant {
assert!(cs.is_satisfied().unwrap());
}
Ok(())
}
#[test]
fn u8_wrapping_add() {
run_binary_exhaustive(uint_wrapping_add::<u8, 8, Fr>).unwrap()
}
#[test]
fn u16_wrapping_add() {
run_binary_random::<1000, 16, _, _>(uint_wrapping_add::<u16, 16, Fr>).unwrap()
}
#[test]
fn u32_wrapping_add() {
run_binary_random::<1000, 32, _, _>(uint_wrapping_add::<u32, 32, Fr>).unwrap()
}
#[test]
fn u64_wrapping_add() {
run_binary_random::<1000, 64, _, _>(uint_wrapping_add::<u64, 64, Fr>).unwrap()
}
#[test]
fn u128_wrapping_add() {
run_binary_random::<1000, 128, _, _>(uint_wrapping_add::<u128, 128, Fr>).unwrap()
}
}

263
src/uint/and.rs Normal file
View File

@@ -0,0 +1,263 @@
use ark_ff::Field;
use ark_relations::r1cs::SynthesisError;
use ark_std::{ops::BitAnd, ops::BitAndAssign};
use super::*;
impl<const N: usize, T: PrimUInt, F: Field> UInt<N, T, F> {
fn _and(&self, other: &Self) -> Result<Self, SynthesisError> {
let mut result = self.clone();
for (a, b) in result.bits.iter_mut().zip(&other.bits) {
*a &= b;
}
result.value = self.value.and_then(|a| Some(a & other.value?));
Ok(result)
}
}
impl<'a, const N: usize, T: PrimUInt, F: Field> BitAnd<Self> for &'a UInt<N, T, F> {
type Output = UInt<N, T, F>;
/// Outputs `self & other`.
///
/// If at least one of `self` and `other` are constants, then this method
/// *does not* create any constraints or variables.
///
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let a = UInt8::new_witness(cs.clone(), || Ok(16))?;
/// let b = UInt8::new_witness(cs.clone(), || Ok(17))?;
/// let c = UInt8::new_witness(cs.clone(), || Ok(16 & 17))?;
///
/// (a & &b).enforce_equal(&c)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand(self, other: Self) -> Self::Output {
self._and(other).unwrap()
}
}
impl<'a, const N: usize, T: PrimUInt, F: Field> BitAnd<&'a Self> for UInt<N, T, F> {
type Output = UInt<N, T, F>;
/// Outputs `self & other`.
///
/// If at least one of `self` and `other` are constants, then this method
/// *does not* create any constraints or variables.
///
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let a = UInt8::new_witness(cs.clone(), || Ok(16))?;
/// let b = UInt8::new_witness(cs.clone(), || Ok(17))?;
/// let c = UInt8::new_witness(cs.clone(), || Ok(16 & 17))?;
///
/// (a & &b).enforce_equal(&c)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand(self, other: &Self) -> Self::Output {
self._and(&other).unwrap()
}
}
impl<'a, const N: usize, T: PrimUInt, F: Field> BitAnd<UInt<N, T, F>> for &'a UInt<N, T, F> {
type Output = UInt<N, T, F>;
/// Outputs `self & other`.
///
/// If at least one of `self` and `other` are constants, then this method
/// *does not* create any constraints or variables.
///
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let a = UInt8::new_witness(cs.clone(), || Ok(16))?;
/// let b = UInt8::new_witness(cs.clone(), || Ok(17))?;
/// let c = UInt8::new_witness(cs.clone(), || Ok(16 & 17))?;
///
/// (a & &b).enforce_equal(&c)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand(self, other: UInt<N, T, F>) -> Self::Output {
self._and(&other).unwrap()
}
}
impl<const N: usize, T: PrimUInt, F: Field> BitAnd<Self> for UInt<N, T, F> {
type Output = Self;
/// Outputs `self & other`.
///
/// If at least one of `self` and `other` are constants, then this method
/// *does not* create any constraints or variables.
///
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let a = UInt8::new_witness(cs.clone(), || Ok(16))?;
/// let b = UInt8::new_witness(cs.clone(), || Ok(17))?;
/// let c = UInt8::new_witness(cs.clone(), || Ok(16 & 17))?;
///
/// (a & &b).enforce_equal(&c)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand(self, other: Self) -> Self::Output {
self._and(&other).unwrap()
}
}
impl<const N: usize, T: PrimUInt, F: Field> BitAndAssign<Self> for UInt<N, T, F> {
/// Sets `self = self & other`.
///
/// If at least one of `self` and `other` are constants, then this method
/// *does not* create any constraints or variables.
///
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let mut a = UInt8::new_witness(cs.clone(), || Ok(16))?;
/// let b = UInt8::new_witness(cs.clone(), || Ok(17))?;
/// let c = UInt8::new_witness(cs.clone(), || Ok(16 & 17))?;
///
/// a &= &b;
/// a.enforce_equal(&c)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand_assign(&mut self, other: Self) {
let result = self._and(&other).unwrap();
*self = result;
}
}
impl<'a, const N: usize, T: PrimUInt, F: Field> BitAndAssign<&'a Self> for UInt<N, T, F> {
/// Sets `self = self & other`.
///
/// If at least one of `self` and `other` are constants, then this method
/// *does not* create any constraints or variables.
///
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let mut a = UInt8::new_witness(cs.clone(), || Ok(16))?;
/// let b = UInt8::new_witness(cs.clone(), || Ok(17))?;
/// let c = UInt8::new_witness(cs.clone(), || Ok(16 & 17))?;
///
/// a &= &b;
/// a.enforce_equal(&c)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitand_assign(&mut self, other: &'a Self) {
let result = self._and(other).unwrap();
*self = result;
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
alloc::{AllocVar, AllocationMode},
prelude::EqGadget,
uint::test_utils::{run_binary_exhaustive, run_binary_random},
R1CSVar,
};
use ark_ff::PrimeField;
use ark_test_curves::bls12_381::Fr;
fn uint_and<T: PrimUInt, const N: usize, F: PrimeField>(
a: UInt<N, T, F>,
b: UInt<N, T, F>,
) -> Result<(), SynthesisError> {
let cs = a.cs().or(b.cs());
let both_constant = a.is_constant() && b.is_constant();
let computed = &a & &b;
let expected_mode = if both_constant {
AllocationMode::Constant
} else {
AllocationMode::Witness
};
let expected = UInt::<N, T, F>::new_variable(
cs.clone(),
|| Ok(a.value()? & b.value()?),
expected_mode,
)?;
assert_eq!(expected.value(), computed.value());
expected.enforce_equal(&computed)?;
if !both_constant {
assert!(cs.is_satisfied().unwrap());
}
Ok(())
}
#[test]
fn u8_and() {
run_binary_exhaustive(uint_and::<u8, 8, Fr>).unwrap()
}
#[test]
fn u16_and() {
run_binary_random::<1000, 16, _, _>(uint_and::<u16, 16, Fr>).unwrap()
}
#[test]
fn u32_and() {
run_binary_random::<1000, 32, _, _>(uint_and::<u32, 32, Fr>).unwrap()
}
#[test]
fn u64_and() {
run_binary_random::<1000, 64, _, _>(uint_and::<u64, 64, Fr>).unwrap()
}
#[test]
fn u128_and() {
run_binary_random::<1000, 128, _, _>(uint_and::<u128, 128, Fr>).unwrap()
}
}

218
src/uint/cmp.rs Normal file
View File

@@ -0,0 +1,218 @@
use crate::cmp::CmpGadget;
use super::*;
impl<const N: usize, T: PrimUInt, F: PrimeField + From<T>> CmpGadget<F> for UInt<N, T, F> {
fn is_ge(&self, other: &Self) -> Result<Boolean<F>, SynthesisError> {
if N + 1 < ((F::MODULUS_BIT_SIZE - 1) as usize) {
let a = self.to_fp()?;
let b = other.to_fp()?;
let (bits, _) = (a - b + F::from(T::max_value()) + F::one())
.to_bits_le_with_top_bits_zero(N + 1)?;
Ok(bits.last().unwrap().clone())
} else {
unimplemented!("bit sizes larger than modulus size not yet supported")
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
alloc::{AllocVar, AllocationMode},
prelude::EqGadget,
uint::test_utils::{run_binary_exhaustive, run_binary_random},
R1CSVar,
};
use ark_ff::PrimeField;
use ark_test_curves::bls12_381::Fr;
fn uint_gt<T: PrimUInt, const N: usize, F: PrimeField + From<T>>(
a: UInt<N, T, F>,
b: UInt<N, T, F>,
) -> Result<(), SynthesisError> {
let cs = a.cs().or(b.cs());
let both_constant = a.is_constant() && b.is_constant();
let expected_mode = if both_constant {
AllocationMode::Constant
} else {
AllocationMode::Witness
};
let computed = a.is_gt(&b)?;
let expected =
Boolean::new_variable(cs.clone(), || Ok(a.value()? > b.value()?), expected_mode)?;
assert_eq!(expected.value(), computed.value());
expected.enforce_equal(&computed)?;
if !both_constant {
assert!(cs.is_satisfied().unwrap());
}
Ok(())
}
fn uint_lt<T: PrimUInt, const N: usize, F: PrimeField + From<T>>(
a: UInt<N, T, F>,
b: UInt<N, T, F>,
) -> Result<(), SynthesisError> {
let cs = a.cs().or(b.cs());
let both_constant = a.is_constant() && b.is_constant();
let expected_mode = if both_constant {
AllocationMode::Constant
} else {
AllocationMode::Witness
};
let computed = a.is_lt(&b)?;
let expected =
Boolean::new_variable(cs.clone(), || Ok(a.value()? < b.value()?), expected_mode)?;
assert_eq!(expected.value(), computed.value());
expected.enforce_equal(&computed)?;
if !both_constant {
assert!(cs.is_satisfied().unwrap());
}
Ok(())
}
fn uint_ge<T: PrimUInt, const N: usize, F: PrimeField + From<T>>(
a: UInt<N, T, F>,
b: UInt<N, T, F>,
) -> Result<(), SynthesisError> {
let cs = a.cs().or(b.cs());
let both_constant = a.is_constant() && b.is_constant();
let expected_mode = if both_constant {
AllocationMode::Constant
} else {
AllocationMode::Witness
};
let computed = a.is_ge(&b)?;
let expected =
Boolean::new_variable(cs.clone(), || Ok(a.value()? >= b.value()?), expected_mode)?;
assert_eq!(expected.value(), computed.value());
expected.enforce_equal(&computed)?;
if !both_constant {
assert!(cs.is_satisfied().unwrap());
}
Ok(())
}
fn uint_le<T: PrimUInt, const N: usize, F: PrimeField + From<T>>(
a: UInt<N, T, F>,
b: UInt<N, T, F>,
) -> Result<(), SynthesisError> {
let cs = a.cs().or(b.cs());
let both_constant = a.is_constant() && b.is_constant();
let expected_mode = if both_constant {
AllocationMode::Constant
} else {
AllocationMode::Witness
};
let computed = a.is_le(&b)?;
let expected =
Boolean::new_variable(cs.clone(), || Ok(a.value()? <= b.value()?), expected_mode)?;
assert_eq!(expected.value(), computed.value());
expected.enforce_equal(&computed)?;
if !both_constant {
assert!(cs.is_satisfied().unwrap());
}
Ok(())
}
#[test]
fn u8_gt() {
run_binary_exhaustive(uint_gt::<u8, 8, Fr>).unwrap()
}
#[test]
fn u16_gt() {
run_binary_random::<1000, 16, _, _>(uint_gt::<u16, 16, Fr>).unwrap()
}
#[test]
fn u32_gt() {
run_binary_random::<1000, 32, _, _>(uint_gt::<u32, 32, Fr>).unwrap()
}
#[test]
fn u64_gt() {
run_binary_random::<1000, 64, _, _>(uint_gt::<u64, 64, Fr>).unwrap()
}
#[test]
fn u128_gt() {
run_binary_random::<1000, 128, _, _>(uint_gt::<u128, 128, Fr>).unwrap()
}
#[test]
fn u8_lt() {
run_binary_exhaustive(uint_lt::<u8, 8, Fr>).unwrap()
}
#[test]
fn u16_lt() {
run_binary_random::<1000, 16, _, _>(uint_lt::<u16, 16, Fr>).unwrap()
}
#[test]
fn u32_lt() {
run_binary_random::<1000, 32, _, _>(uint_lt::<u32, 32, Fr>).unwrap()
}
#[test]
fn u64_lt() {
run_binary_random::<1000, 64, _, _>(uint_lt::<u64, 64, Fr>).unwrap()
}
#[test]
fn u128_lt() {
run_binary_random::<1000, 128, _, _>(uint_lt::<u128, 128, Fr>).unwrap()
}
#[test]
fn u8_le() {
run_binary_exhaustive(uint_le::<u8, 8, Fr>).unwrap()
}
#[test]
fn u16_le() {
run_binary_random::<1000, 16, _, _>(uint_le::<u16, 16, Fr>).unwrap()
}
#[test]
fn u32_le() {
run_binary_random::<1000, 32, _, _>(uint_le::<u32, 32, Fr>).unwrap()
}
#[test]
fn u64_le() {
run_binary_random::<1000, 64, _, _>(uint_le::<u64, 64, Fr>).unwrap()
}
#[test]
fn u128_le() {
run_binary_random::<1000, 128, _, _>(uint_le::<u128, 128, Fr>).unwrap()
}
#[test]
fn u8_ge() {
run_binary_exhaustive(uint_ge::<u8, 8, Fr>).unwrap()
}
#[test]
fn u16_ge() {
run_binary_random::<1000, 16, _, _>(uint_ge::<u16, 16, Fr>).unwrap()
}
#[test]
fn u32_ge() {
run_binary_random::<1000, 32, _, _>(uint_ge::<u32, 32, Fr>).unwrap()
}
#[test]
fn u64_ge() {
run_binary_random::<1000, 64, _, _>(uint_ge::<u64, 64, Fr>).unwrap()
}
#[test]
fn u128_ge() {
run_binary_random::<1000, 128, _, _>(uint_ge::<u128, 128, Fr>).unwrap()
}
}

129
src/uint/convert.rs Normal file
View File

@@ -0,0 +1,129 @@
use crate::convert::*;
use crate::fields::fp::FpVar;
use super::*;
impl<const N: usize, F: Field, T: PrimUInt> UInt<N, T, F> {
/// Converts `self` into a field element. The elements comprising `self` are
/// interpreted as a little-endian bit order representation of a field element.
///
/// # Panics
/// Assumes that `N` is equal to at most the number of bits in `F::MODULUS_BIT_SIZE - 1`, and panics otherwise.
pub fn to_fp(&self) -> Result<FpVar<F>, SynthesisError>
where
F: PrimeField,
{
assert!(N <= F::MODULUS_BIT_SIZE as usize - 1);
Boolean::le_bits_to_fp(&self.bits)
}
/// Converts a field element into its little-endian bit order representation.
///
/// # Panics
///
/// Assumes that `N` is at most the number of bits in `F::MODULUS_BIT_SIZE - 1`, and panics otherwise.
pub fn from_fp(other: &FpVar<F>) -> Result<(Self, FpVar<F>), SynthesisError>
where
F: PrimeField,
{
let (bits, rest) = other.to_bits_le_with_top_bits_zero(N)?;
let result = Self::from_bits_le(&bits);
Ok((result, rest))
}
/// Converts a little-endian byte order representation of bits into a
/// `UInt`.
///
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let var = UInt8::new_witness(cs.clone(), || Ok(128))?;
///
/// let f = Boolean::FALSE;
/// let t = Boolean::TRUE;
///
/// // Construct [0, 0, 0, 0, 0, 0, 0, 1]
/// let mut bits = vec![f.clone(); 7];
/// bits.push(t);
///
/// let mut c = UInt8::from_bits_le(&bits);
/// var.enforce_equal(&c)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
#[tracing::instrument(target = "r1cs")]
pub fn from_bits_le(bits: &[Boolean<F>]) -> Self {
assert_eq!(bits.len(), N);
let bits = <&[Boolean<F>; N]>::try_from(bits).unwrap().clone();
let value_exists = bits.iter().all(|b| b.value().is_ok());
let mut value = T::zero();
for (i, b) in bits.iter().enumerate() {
if let Ok(b) = b.value() {
value = value + (T::from(b as u8).unwrap() << i);
}
}
let value = value_exists.then_some(value);
Self { bits, value }
}
}
impl<const N: usize, T: PrimUInt, F: Field> ToBitsGadget<F> for UInt<N, T, F> {
fn to_bits_le(&self) -> Result<Vec<Boolean<F>>, SynthesisError> {
Ok(self.bits.to_vec())
}
}
impl<const N: usize, T: PrimUInt, F: Field> ToBitsGadget<F> for [UInt<N, T, F>] {
/// Interprets `self` as an integer, and outputs the little-endian
/// bit-wise decomposition of that integer.
fn to_bits_le(&self) -> Result<Vec<Boolean<F>>, SynthesisError> {
let bits = self.iter().flat_map(|b| &b.bits).cloned().collect();
Ok(bits)
}
}
/*****************************************************************************************/
/********************************* Conversions to bytes. *********************************/
/*****************************************************************************************/
impl<const N: usize, T: PrimUInt, ConstraintF: Field> ToBytesGadget<ConstraintF>
for UInt<N, T, ConstraintF>
{
#[tracing::instrument(target = "r1cs", skip(self))]
fn to_bytes(&self) -> Result<Vec<UInt8<ConstraintF>>, SynthesisError> {
Ok(self
.to_bits_le()?
.chunks(8)
.map(UInt8::from_bits_le)
.collect())
}
}
impl<const N: usize, T: PrimUInt, F: Field> ToBytesGadget<F> for [UInt<N, T, F>] {
fn to_bytes(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
let mut bytes = Vec::with_capacity(self.len() * (N / 8));
for elem in self {
bytes.extend_from_slice(&elem.to_bytes()?);
}
Ok(bytes)
}
}
impl<const N: usize, T: PrimUInt, F: Field> ToBytesGadget<F> for Vec<UInt<N, T, F>> {
fn to_bytes(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
self.as_slice().to_bytes()
}
}
impl<'a, const N: usize, T: PrimUInt, F: Field> ToBytesGadget<F> for &'a [UInt<N, T, F>] {
fn to_bytes(&self) -> Result<Vec<UInt8<F>>, SynthesisError> {
(*self).to_bytes()
}
}

173
src/uint/eq.rs Normal file
View File

@@ -0,0 +1,173 @@
use ark_ff::PrimeField;
use ark_relations::r1cs::SynthesisError;
use ark_std::vec::Vec;
use crate::boolean::Boolean;
use crate::eq::EqGadget;
use super::*;
impl<const N: usize, T: PrimUInt, ConstraintF: PrimeField> EqGadget<ConstraintF>
for UInt<N, T, ConstraintF>
{
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn is_eq(&self, other: &Self) -> Result<Boolean<ConstraintF>, SynthesisError> {
let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap();
let chunks_are_eq = self
.bits
.chunks(chunk_size)
.zip(other.bits.chunks(chunk_size))
.map(|(a, b)| {
let a = Boolean::le_bits_to_fp(a)?;
let b = Boolean::le_bits_to_fp(b)?;
a.is_eq(&b)
})
.collect::<Result<Vec<_>, _>>()?;
Boolean::kary_and(&chunks_are_eq)
}
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn conditional_enforce_equal(
&self,
other: &Self,
condition: &Boolean<ConstraintF>,
) -> Result<(), SynthesisError> {
let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap();
for (a, b) in self
.bits
.chunks(chunk_size)
.zip(other.bits.chunks(chunk_size))
{
let a = Boolean::le_bits_to_fp(a)?;
let b = Boolean::le_bits_to_fp(b)?;
a.conditional_enforce_equal(&b, condition)?;
}
Ok(())
}
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn conditional_enforce_not_equal(
&self,
other: &Self,
condition: &Boolean<ConstraintF>,
) -> Result<(), SynthesisError> {
let chunk_size = usize::try_from(ConstraintF::MODULUS_BIT_SIZE - 1).unwrap();
for (a, b) in self
.bits
.chunks(chunk_size)
.zip(other.bits.chunks(chunk_size))
{
let a = Boolean::le_bits_to_fp(a)?;
let b = Boolean::le_bits_to_fp(b)?;
a.conditional_enforce_not_equal(&b, condition)?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
alloc::{AllocVar, AllocationMode},
prelude::EqGadget,
uint::test_utils::{run_binary_exhaustive, run_binary_random},
R1CSVar,
};
use ark_ff::PrimeField;
use ark_test_curves::bls12_381::Fr;
fn uint_eq<T: PrimUInt, const N: usize, F: PrimeField>(
a: UInt<N, T, F>,
b: UInt<N, T, F>,
) -> Result<(), SynthesisError> {
let cs = a.cs().or(b.cs());
let both_constant = a.is_constant() && b.is_constant();
let computed = a.is_eq(&b)?;
let expected_mode = if both_constant {
AllocationMode::Constant
} else {
AllocationMode::Witness
};
let expected =
Boolean::new_variable(cs.clone(), || Ok(a.value()? == b.value()?), expected_mode)?;
assert_eq!(expected.value(), computed.value());
expected.enforce_equal(&computed)?;
if !both_constant {
assert!(cs.is_satisfied().unwrap());
}
Ok(())
}
fn uint_neq<T: PrimUInt, const N: usize, F: PrimeField>(
a: UInt<N, T, F>,
b: UInt<N, T, F>,
) -> Result<(), SynthesisError> {
let cs = a.cs().or(b.cs());
let both_constant = a.is_constant() && b.is_constant();
let computed = a.is_neq(&b)?;
let expected_mode = if both_constant {
AllocationMode::Constant
} else {
AllocationMode::Witness
};
let expected =
Boolean::new_variable(cs.clone(), || Ok(a.value()? != b.value()?), expected_mode)?;
assert_eq!(expected.value(), computed.value());
expected.enforce_equal(&computed)?;
if !both_constant {
assert!(cs.is_satisfied().unwrap());
}
Ok(())
}
#[test]
fn u8_eq() {
run_binary_exhaustive(uint_eq::<u8, 8, Fr>).unwrap()
}
#[test]
fn u16_eq() {
run_binary_random::<1000, 16, _, _>(uint_eq::<u16, 16, Fr>).unwrap()
}
#[test]
fn u32_eq() {
run_binary_random::<1000, 32, _, _>(uint_eq::<u32, 32, Fr>).unwrap()
}
#[test]
fn u64_eq() {
run_binary_random::<1000, 64, _, _>(uint_eq::<u64, 64, Fr>).unwrap()
}
#[test]
fn u128_eq() {
run_binary_random::<1000, 128, _, _>(uint_eq::<u128, 128, Fr>).unwrap()
}
#[test]
fn u8_neq() {
run_binary_exhaustive(uint_neq::<u8, 8, Fr>).unwrap()
}
#[test]
fn u16_neq() {
run_binary_random::<1000, 16, _, _>(uint_neq::<u16, 16, Fr>).unwrap()
}
#[test]
fn u32_neq() {
run_binary_random::<1000, 32, _, _>(uint_neq::<u32, 32, Fr>).unwrap()
}
#[test]
fn u64_neq() {
run_binary_random::<1000, 64, _, _>(uint_neq::<u64, 64, Fr>).unwrap()
}
#[test]
fn u128_neq() {
run_binary_random::<1000, 128, _, _>(uint_neq::<u128, 128, Fr>).unwrap()
}
}

160
src/uint/mod.rs Normal file
View File

@@ -0,0 +1,160 @@
use ark_ff::{Field, PrimeField};
use core::{borrow::Borrow, convert::TryFrom, fmt::Debug};
use ark_relations::r1cs::{ConstraintSystemRef, Namespace, SynthesisError};
use crate::{boolean::Boolean, prelude::*, Assignment, Vec};
mod add;
mod and;
mod cmp;
mod convert;
mod eq;
mod not;
mod or;
mod rotate;
mod select;
mod shl;
mod shr;
mod xor;
#[doc(hidden)]
pub mod prim_uint;
pub use prim_uint::*;
#[cfg(test)]
pub(crate) mod test_utils;
/// This struct represent an unsigned `N` bit integer as a sequence of `N` [`Boolean`]s.
#[derive(Clone, Debug)]
pub struct UInt<const N: usize, T: PrimUInt, F: Field> {
#[doc(hidden)]
pub bits: [Boolean<F>; N],
#[doc(hidden)]
pub value: Option<T>,
}
impl<const N: usize, T: PrimUInt, F: Field> R1CSVar<F> for UInt<N, T, F> {
type Value = T;
fn cs(&self) -> ConstraintSystemRef<F> {
self.bits.as_ref().cs()
}
fn value(&self) -> Result<Self::Value, SynthesisError> {
let mut value = T::zero();
for (i, bit) in self.bits.iter().enumerate() {
value = value + (T::from(bit.value()? as u8).unwrap() << i);
}
debug_assert_eq!(self.value, Some(value));
Ok(value)
}
}
impl<const N: usize, T: PrimUInt, F: Field> UInt<N, T, F> {
pub const MAX: Self = Self {
bits: [Boolean::TRUE; N],
value: Some(T::MAX),
};
/// Construct a constant [`UInt`] from the native unsigned integer type.
///
/// This *does not* create new variables or constraints.
///
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let var = UInt8::new_witness(cs.clone(), || Ok(2))?;
///
/// let constant = UInt8::constant(2);
/// var.enforce_equal(&constant)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
pub fn constant(value: T) -> Self {
let mut bits = [Boolean::FALSE; N];
let mut bit_values = value;
for i in 0..N {
bits[i] = Boolean::constant((bit_values & T::one()) == T::one());
bit_values = bit_values >> 1u8;
}
Self {
bits,
value: Some(value),
}
}
/// Construct a constant vector of [`UInt`] from a vector of the native type
///
/// This *does not* create any new variables or constraints.
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let var = vec![UInt8::new_witness(cs.clone(), || Ok(2))?];
///
/// let constant = UInt8::constant_vec(&[2]);
/// var.enforce_equal(&constant)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
pub fn constant_vec(values: &[T]) -> Vec<Self> {
values.iter().map(|v| Self::constant(*v)).collect()
}
/// Allocates a slice of `uN`'s as private witnesses.
pub fn new_witness_vec(
cs: impl Into<Namespace<F>>,
values: &[impl Into<Option<T>> + Copy],
) -> Result<Vec<Self>, SynthesisError> {
let ns = cs.into();
let cs = ns.cs();
let mut output_vec = Vec::with_capacity(values.len());
for value in values {
let byte: Option<T> = Into::into(*value);
output_vec.push(Self::new_witness(cs.clone(), || byte.get())?);
}
Ok(output_vec)
}
}
impl<const N: usize, T: PrimUInt, ConstraintF: Field> AllocVar<T, ConstraintF>
for UInt<N, T, ConstraintF>
{
fn new_variable<S: Borrow<T>>(
cs: impl Into<Namespace<ConstraintF>>,
f: impl FnOnce() -> Result<S, SynthesisError>,
mode: AllocationMode,
) -> Result<Self, SynthesisError> {
let ns = cs.into();
let cs = ns.cs();
let value = f().map(|f| *f.borrow()).ok();
let mut values = [None; N];
if let Some(val) = value {
values
.iter_mut()
.enumerate()
.for_each(|(i, v)| *v = Some(((val >> i) & T::one()) == T::one()));
}
let mut bits = [Boolean::FALSE; N];
for (b, v) in bits.iter_mut().zip(&values) {
*b = Boolean::new_variable(cs.clone(), || v.get(), mode)?;
}
Ok(Self { bits, value })
}
}

131
src/uint/not.rs Normal file
View File

@@ -0,0 +1,131 @@
use ark_ff::Field;
use ark_relations::r1cs::SynthesisError;
use ark_std::ops::Not;
use super::*;
impl<const N: usize, T: PrimUInt, F: Field> UInt<N, T, F> {
fn _not(&self) -> Result<Self, SynthesisError> {
let mut result = self.clone();
for a in &mut result.bits {
*a = !&*a
}
result.value = self.value.map(Not::not);
Ok(result)
}
}
impl<'a, const N: usize, T: PrimUInt, F: Field> Not for &'a UInt<N, T, F> {
type Output = UInt<N, T, F>;
/// Outputs `!self`.
///
/// If `self` is a constant, then this method *does not* create any constraints or variables.
///
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let a = UInt8::new_witness(cs.clone(), || Ok(2))?;
/// let b = UInt8::new_witness(cs.clone(), || Ok(!2))?;
///
/// (!a).enforce_equal(&b)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
#[tracing::instrument(target = "r1cs", skip(self))]
fn not(self) -> Self::Output {
self._not().unwrap()
}
}
impl<'a, const N: usize, T: PrimUInt, F: Field> Not for UInt<N, T, F> {
type Output = UInt<N, T, F>;
/// Outputs `!self`.
///
/// If `self` is a constant, then this method *does not* create any constraints or variables.
///
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let a = UInt8::new_witness(cs.clone(), || Ok(2))?;
/// let b = UInt8::new_witness(cs.clone(), || Ok(!2))?;
///
/// (!a).enforce_equal(&b)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
#[tracing::instrument(target = "r1cs", skip(self))]
fn not(self) -> Self::Output {
self._not().unwrap()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
alloc::{AllocVar, AllocationMode},
prelude::EqGadget,
uint::test_utils::{run_unary_exhaustive, run_unary_random},
R1CSVar,
};
use ark_ff::PrimeField;
use ark_test_curves::bls12_381::Fr;
fn uint_not<T: PrimUInt, const N: usize, F: PrimeField>(
a: UInt<N, T, F>,
) -> Result<(), SynthesisError> {
let cs = a.cs();
let computed = !&a;
let expected_mode = if a.is_constant() {
AllocationMode::Constant
} else {
AllocationMode::Witness
};
let expected =
UInt::<N, T, F>::new_variable(cs.clone(), || Ok(!a.value()?), expected_mode)?;
assert_eq!(expected.value(), computed.value());
expected.enforce_equal(&computed)?;
if !a.is_constant() {
assert!(cs.is_satisfied().unwrap());
}
Ok(())
}
#[test]
fn u8_not() {
run_unary_exhaustive(uint_not::<u8, 8, Fr>).unwrap()
}
#[test]
fn u16_not() {
run_unary_random::<1000, 16, _, _>(uint_not::<u16, 16, Fr>).unwrap()
}
#[test]
fn u32_not() {
run_unary_random::<1000, 32, _, _>(uint_not::<u32, 32, Fr>).unwrap()
}
#[test]
fn u64_not() {
run_unary_random::<1000, 64, _, _>(uint_not::<u64, 64, Fr>).unwrap()
}
#[test]
fn u128() {
run_unary_random::<1000, 128, _, _>(uint_not::<u128, 128, Fr>).unwrap()
}
}

176
src/uint/or.rs Normal file
View File

@@ -0,0 +1,176 @@
use ark_ff::PrimeField;
use ark_relations::r1cs::SynthesisError;
use ark_std::{ops::BitOr, ops::BitOrAssign};
use super::{PrimUInt, UInt};
impl<const N: usize, T: PrimUInt, F: PrimeField> UInt<N, T, F> {
fn _or(&self, other: &Self) -> Result<Self, SynthesisError> {
let mut result = self.clone();
for (a, b) in result.bits.iter_mut().zip(&other.bits) {
*a |= b;
}
result.value = self.value.and_then(|a| Some(a | other.value?));
Ok(result)
}
}
impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr<Self> for &'a UInt<N, T, F> {
type Output = UInt<N, T, F>;
/// Output `self | other`.
///
/// If at least one of `self` and `other` are constants, then this method
/// *does not* create any constraints or variables.
///
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let a = UInt8::new_witness(cs.clone(), || Ok(16))?;
/// let b = UInt8::new_witness(cs.clone(), || Ok(17))?;
/// let c = UInt8::new_witness(cs.clone(), || Ok(16 | 17))?;
///
/// (a | b).enforce_equal(&c)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitor(self, other: Self) -> Self::Output {
self._or(other).unwrap()
}
}
impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr<&'a Self> for UInt<N, T, F> {
type Output = UInt<N, T, F>;
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitor(self, other: &Self) -> Self::Output {
self._or(&other).unwrap()
}
}
impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOr<UInt<N, T, F>> for &'a UInt<N, T, F> {
type Output = UInt<N, T, F>;
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitor(self, other: UInt<N, T, F>) -> Self::Output {
self._or(&other).unwrap()
}
}
impl<const N: usize, T: PrimUInt, F: PrimeField> BitOr<Self> for UInt<N, T, F> {
type Output = Self;
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitor(self, other: Self) -> Self::Output {
self._or(&other).unwrap()
}
}
impl<const N: usize, T: PrimUInt, F: PrimeField> BitOrAssign<Self> for UInt<N, T, F> {
/// Sets `self = self | other`.
///
/// If at least one of `self` and `other` are constants, then this method
/// *does not* create any constraints or variables.
///
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let mut a = UInt8::new_witness(cs.clone(), || Ok(16))?;
/// let b = UInt8::new_witness(cs.clone(), || Ok(17))?;
/// let c = UInt8::new_witness(cs.clone(), || Ok(16 | 17))?;
///
/// a |= b;
/// a.enforce_equal(&c)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitor_assign(&mut self, other: Self) {
let result = self._or(&other).unwrap();
*self = result;
}
}
impl<'a, const N: usize, T: PrimUInt, F: PrimeField> BitOrAssign<&'a Self> for UInt<N, T, F> {
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitor_assign(&mut self, other: &'a Self) {
let result = self._or(other).unwrap();
*self = result;
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
alloc::{AllocVar, AllocationMode},
prelude::EqGadget,
uint::test_utils::{run_binary_exhaustive, run_binary_random},
R1CSVar,
};
use ark_ff::PrimeField;
use ark_test_curves::bls12_381::Fr;
fn uint_or<T: PrimUInt, const N: usize, F: PrimeField>(
a: UInt<N, T, F>,
b: UInt<N, T, F>,
) -> Result<(), SynthesisError> {
let cs = a.cs().or(b.cs());
let both_constant = a.is_constant() && b.is_constant();
let computed = &a | &b;
let expected_mode = if both_constant {
AllocationMode::Constant
} else {
AllocationMode::Witness
};
let expected = UInt::<N, T, F>::new_variable(
cs.clone(),
|| Ok(a.value()? | b.value()?),
expected_mode,
)?;
assert_eq!(expected.value(), computed.value());
expected.enforce_equal(&computed)?;
if !both_constant {
assert!(cs.is_satisfied().unwrap());
}
Ok(())
}
#[test]
fn u8_or() {
run_binary_exhaustive(uint_or::<u8, 8, Fr>).unwrap()
}
#[test]
fn u16_or() {
run_binary_random::<1000, 16, _, _>(uint_or::<u16, 16, Fr>).unwrap()
}
#[test]
fn u32_or() {
run_binary_random::<1000, 32, _, _>(uint_or::<u32, 32, Fr>).unwrap()
}
#[test]
fn u64_or() {
run_binary_random::<1000, 64, _, _>(uint_or::<u64, 64, Fr>).unwrap()
}
#[test]
fn u128_or() {
run_binary_random::<1000, 128, _, _>(uint_or::<u128, 128, Fr>).unwrap()
}
}

175
src/uint/prim_uint.rs Normal file
View File

@@ -0,0 +1,175 @@
use core::ops::{Shl, ShlAssign, Shr, ShrAssign};
use core::usize;
#[doc(hidden)]
// Adapted from <https://github.com/rust-num/num-traits/pull/224>
pub trait PrimUInt:
core::fmt::Debug
+ num_traits::PrimInt
+ num_traits::WrappingAdd
+ num_traits::SaturatingAdd
+ Shl<usize, Output = Self>
+ Shl<u8, Output = Self>
+ Shl<u16, Output = Self>
+ Shl<u32, Output = Self>
+ Shl<u64, Output = Self>
+ Shl<u128, Output = Self>
+ Shr<usize, Output = Self>
+ Shr<u8, Output = Self>
+ Shr<u16, Output = Self>
+ Shr<u32, Output = Self>
+ Shr<u64, Output = Self>
+ Shr<u128, Output = Self>
+ ShlAssign<usize>
+ ShlAssign<u8>
+ ShlAssign<u16>
+ ShlAssign<u32>
+ ShlAssign<u64>
+ ShlAssign<u128>
+ ShrAssign<usize>
+ ShrAssign<u8>
+ ShrAssign<u16>
+ ShrAssign<u32>
+ ShrAssign<u64>
+ ShrAssign<u128>
+ Into<u128>
+ _private::Sealed
+ ark_std::UniformRand
{
type Bytes: NumBytes;
const MAX: Self;
#[doc(hidden)]
const MAX_VALUE_BIT_DECOMP: &'static [bool];
/// Return the memory representation of this number as a byte array in little-endian byte order.
///
/// # Examples
///
/// ```
/// use ark_r1cs_std::uint::PrimUInt;
///
/// let bytes = PrimUInt::to_le_bytes(&0x12345678u32);
/// assert_eq!(bytes, [0x78, 0x56, 0x34, 0x12]);
/// ```
fn to_le_bytes(&self) -> Self::Bytes;
/// Return the memory representation of this number as a byte array in big-endian byte order.
///
/// # Examples
///
/// ```
/// use ark_r1cs_std::uint::PrimUInt;
///
/// let bytes = PrimUInt::to_be_bytes(&0x12345678u32);
/// assert_eq!(bytes, [0x12, 0x34, 0x56, 0x78]);
/// ```
fn to_be_bytes(&self) -> Self::Bytes;
}
impl PrimUInt for u8 {
const MAX: Self = u8::MAX;
const MAX_VALUE_BIT_DECOMP: &'static [bool] = &[true; 8];
type Bytes = [u8; 1];
#[inline]
fn to_le_bytes(&self) -> Self::Bytes {
u8::to_le_bytes(*self)
}
#[inline]
fn to_be_bytes(&self) -> Self::Bytes {
u8::to_be_bytes(*self)
}
}
impl PrimUInt for u16 {
const MAX: Self = u16::MAX;
const MAX_VALUE_BIT_DECOMP: &'static [bool] = &[true; 16];
type Bytes = [u8; 2];
#[inline]
fn to_le_bytes(&self) -> Self::Bytes {
u16::to_le_bytes(*self)
}
#[inline]
fn to_be_bytes(&self) -> Self::Bytes {
u16::to_be_bytes(*self)
}
}
impl PrimUInt for u32 {
const MAX: Self = u32::MAX;
const MAX_VALUE_BIT_DECOMP: &'static [bool] = &[true; 32];
type Bytes = [u8; 4];
#[inline]
fn to_le_bytes(&self) -> Self::Bytes {
u32::to_le_bytes(*self)
}
#[inline]
fn to_be_bytes(&self) -> Self::Bytes {
u32::to_be_bytes(*self)
}
}
impl PrimUInt for u64 {
const MAX: Self = u64::MAX;
const MAX_VALUE_BIT_DECOMP: &'static [bool] = &[true; 64];
type Bytes = [u8; 8];
#[inline]
fn to_le_bytes(&self) -> Self::Bytes {
u64::to_le_bytes(*self)
}
#[inline]
fn to_be_bytes(&self) -> Self::Bytes {
u64::to_be_bytes(*self)
}
}
impl PrimUInt for u128 {
const MAX: Self = u128::MAX;
const MAX_VALUE_BIT_DECOMP: &'static [bool] = &[true; 128];
type Bytes = [u8; 16];
#[inline]
fn to_le_bytes(&self) -> Self::Bytes {
u128::to_le_bytes(*self)
}
#[inline]
fn to_be_bytes(&self) -> Self::Bytes {
u128::to_be_bytes(*self)
}
}
#[doc(hidden)]
pub trait NumBytes:
core::fmt::Debug
+ AsRef<[u8]>
+ AsMut<[u8]>
+ PartialEq
+ Eq
+ PartialOrd
+ Ord
+ core::hash::Hash
+ core::borrow::Borrow<[u8]>
+ core::borrow::BorrowMut<[u8]>
{
}
#[doc(hidden)]
impl<const N: usize> NumBytes for [u8; N] {}
mod _private {
pub trait Sealed {}
impl Sealed for u8 {}
impl Sealed for u16 {}
impl Sealed for u32 {}
impl Sealed for u64 {}
impl Sealed for u128 {}
}

174
src/uint/rotate.rs Normal file
View File

@@ -0,0 +1,174 @@
use super::*;
impl<const N: usize, T: PrimUInt, ConstraintF: Field> UInt<N, T, ConstraintF> {
/// Rotates `self` to the right by `by` steps, wrapping around.
///
/// # Examples
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let a = UInt32::new_witness(cs.clone(), || Ok(0xb301u32))?;
/// let b = UInt32::new_witness(cs.clone(), || Ok(0x10000b3))?;
///
/// a.rotate_right(8).enforce_equal(&b)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
#[tracing::instrument(target = "r1cs", skip(self))]
pub fn rotate_right(&self, by: usize) -> Self {
let by = by % N;
let mut result = self.clone();
// `[T]::rotate_left` corresponds to a `rotate_right` of the bits.
result.bits.rotate_left(by);
result.value = self.value.map(|v| v.rotate_right(by as u32));
result
}
/// Rotates `self` to the left by `by` steps, wrapping around.
///
/// # Examples
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let a = UInt32::new_witness(cs.clone(), || Ok(0x10000b3))?;
/// let b = UInt32::new_witness(cs.clone(), || Ok(0xb301u32))?;
///
/// a.rotate_left(8).enforce_equal(&b)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
#[tracing::instrument(target = "r1cs", skip(self))]
pub fn rotate_left(&self, by: usize) -> Self {
let by = by % N;
let mut result = self.clone();
// `[T]::rotate_right` corresponds to a `rotate_left` of the bits.
result.bits.rotate_right(by);
result.value = self.value.map(|v| v.rotate_left(by as u32));
result
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
alloc::{AllocVar, AllocationMode},
prelude::EqGadget,
uint::test_utils::{run_unary_exhaustive, run_unary_random},
R1CSVar,
};
use ark_ff::PrimeField;
use ark_test_curves::bls12_381::Fr;
fn uint_rotate_left<T: PrimUInt, const N: usize, F: PrimeField>(
a: UInt<N, T, F>,
) -> Result<(), SynthesisError> {
let cs = a.cs();
let expected_mode = if a.is_constant() {
AllocationMode::Constant
} else {
AllocationMode::Witness
};
for shift in 0..N {
let computed = a.rotate_left(shift);
let expected = UInt::<N, T, F>::new_variable(
cs.clone(),
|| Ok(a.value()?.rotate_left(shift as u32)),
expected_mode,
)?;
assert_eq!(expected.value(), computed.value());
expected.enforce_equal(&computed)?;
if !a.is_constant() {
assert!(cs.is_satisfied().unwrap());
}
}
Ok(())
}
fn uint_rotate_right<T: PrimUInt, const N: usize, F: PrimeField>(
a: UInt<N, T, F>,
) -> Result<(), SynthesisError> {
let cs = a.cs();
let expected_mode = if a.is_constant() {
AllocationMode::Constant
} else {
AllocationMode::Witness
};
for shift in 0..N {
let computed = a.rotate_right(shift);
let expected = UInt::<N, T, F>::new_variable(
cs.clone(),
|| Ok(a.value()?.rotate_right(shift as u32)),
expected_mode,
)?;
assert_eq!(expected.value(), computed.value());
expected.enforce_equal(&computed)?;
if !a.is_constant() {
assert!(cs.is_satisfied().unwrap());
}
}
Ok(())
}
#[test]
fn u8_rotate_left() {
run_unary_exhaustive(uint_rotate_left::<u8, 8, Fr>).unwrap()
}
#[test]
fn u16_rotate_left() {
run_unary_random::<1000, 16, _, _>(uint_rotate_left::<u16, 16, Fr>).unwrap()
}
#[test]
fn u32_rotate_left() {
run_unary_random::<1000, 32, _, _>(uint_rotate_left::<u32, 32, Fr>).unwrap()
}
#[test]
fn u64_rotate_left() {
run_unary_random::<200, 64, _, _>(uint_rotate_left::<u64, 64, Fr>).unwrap()
}
#[test]
fn u128_rotate_left() {
run_unary_random::<100, 128, _, _>(uint_rotate_left::<u128, 128, Fr>).unwrap()
}
#[test]
fn u8_rotate_right() {
run_unary_exhaustive(uint_rotate_right::<u8, 8, Fr>).unwrap()
}
#[test]
fn u16_rotate_right() {
run_unary_random::<1000, 16, _, _>(uint_rotate_right::<u16, 16, Fr>).unwrap()
}
#[test]
fn u32_rotate_right() {
run_unary_random::<1000, 32, _, _>(uint_rotate_right::<u32, 32, Fr>).unwrap()
}
#[test]
fn u64_rotate_right() {
run_unary_random::<200, 64, _, _>(uint_rotate_right::<u64, 64, Fr>).unwrap()
}
#[test]
fn u128_rotate_right() {
run_unary_random::<100, 128, _, _>(uint_rotate_right::<u128, 128, Fr>).unwrap()
}
}

98
src/uint/select.rs Normal file
View File

@@ -0,0 +1,98 @@
use super::*;
use crate::select::CondSelectGadget;
impl<const N: usize, T: PrimUInt, ConstraintF: PrimeField> CondSelectGadget<ConstraintF>
for UInt<N, T, ConstraintF>
{
#[tracing::instrument(target = "r1cs", skip(cond, true_value, false_value))]
fn conditionally_select(
cond: &Boolean<ConstraintF>,
true_value: &Self,
false_value: &Self,
) -> Result<Self, SynthesisError> {
let selected_bits = true_value
.bits
.iter()
.zip(&false_value.bits)
.map(|(t, f)| cond.select(t, f));
let mut bits = [Boolean::FALSE; N];
for (result, new) in bits.iter_mut().zip(selected_bits) {
*result = new?;
}
let value = cond.value().ok().and_then(|cond| {
if cond {
true_value.value().ok()
} else {
false_value.value().ok()
}
});
Ok(Self { bits, value })
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
alloc::{AllocVar, AllocationMode},
prelude::EqGadget,
uint::test_utils::{run_binary_exhaustive, run_binary_random},
};
use ark_ff::PrimeField;
use ark_test_curves::bls12_381::Fr;
fn uint_select<T: PrimUInt, const N: usize, F: PrimeField>(
a: UInt<N, T, F>,
b: UInt<N, T, F>,
) -> Result<(), SynthesisError> {
let cs = a.cs().or(b.cs());
let both_constant = a.is_constant() && b.is_constant();
let expected_mode = if both_constant {
AllocationMode::Constant
} else {
AllocationMode::Witness
};
for cond in [true, false] {
let expected = UInt::new_variable(
cs.clone(),
|| Ok(if cond { a.value()? } else { b.value()? }),
expected_mode,
)?;
let cond = Boolean::new_variable(cs.clone(), || Ok(cond), expected_mode)?;
let computed = cond.select(&a, &b)?;
assert_eq!(expected.value(), computed.value());
expected.enforce_equal(&computed)?;
if !both_constant {
assert!(cs.is_satisfied().unwrap());
}
}
Ok(())
}
#[test]
fn u8_select() {
run_binary_exhaustive(uint_select::<u8, 8, Fr>).unwrap()
}
#[test]
fn u16_select() {
run_binary_random::<1000, 16, _, _>(uint_select::<u16, 16, Fr>).unwrap()
}
#[test]
fn u32_select() {
run_binary_random::<1000, 32, _, _>(uint_select::<u32, 32, Fr>).unwrap()
}
#[test]
fn u64_select() {
run_binary_random::<1000, 64, _, _>(uint_select::<u64, 64, Fr>).unwrap()
}
#[test]
fn u128_select() {
run_binary_random::<1000, 128, _, _>(uint_select::<u128, 128, Fr>).unwrap()
}
}

154
src/uint/shl.rs Normal file
View File

@@ -0,0 +1,154 @@
use ark_ff::PrimeField;
use ark_relations::r1cs::SynthesisError;
use ark_std::{ops::Shl, ops::ShlAssign};
use crate::boolean::Boolean;
use super::{PrimUInt, UInt};
impl<const N: usize, T: PrimUInt, F: PrimeField> UInt<N, T, F> {
fn _shl_u128(&self, other: u128) -> Result<Self, SynthesisError> {
if other < N as u128 {
let mut bits = [Boolean::FALSE; N];
for (a, b) in bits[other as usize..].iter_mut().zip(&self.bits) {
*a = b.clone();
}
let value = self.value.and_then(|a| Some(a << other));
Ok(Self { bits, value })
} else {
panic!("attempt to shift left with overflow")
}
}
}
impl<const N: usize, T: PrimUInt, F: PrimeField, T2: PrimUInt> Shl<T2> for UInt<N, T, F> {
type Output = Self;
/// Output `self << other`.
///
/// If at least one of `self` and `other` are constants, then this method
/// *does not* create any constraints or variables.
///
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let a = UInt8::new_witness(cs.clone(), || Ok(16))?;
/// let b = 1u8;
/// let c = UInt8::new_witness(cs.clone(), || Ok(16 << 1))?;
///
/// (a << b).enforce_equal(&c)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn shl(self, other: T2) -> Self::Output {
self._shl_u128(other.into()).unwrap()
}
}
impl<'a, const N: usize, T: PrimUInt, F: PrimeField, T2: PrimUInt> Shl<T2> for &'a UInt<N, T, F> {
type Output = UInt<N, T, F>;
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn shl(self, other: T2) -> Self::Output {
self._shl_u128(other.into()).unwrap()
}
}
impl<const N: usize, T: PrimUInt, F: PrimeField, T2: PrimUInt> ShlAssign<T2> for UInt<N, T, F> {
/// Sets `self = self << other`.
///
/// If at least one of `self` and `other` are constants, then this method
/// *does not* create any constraints or variables.
///
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let mut a = UInt8::new_witness(cs.clone(), || Ok(16))?;
/// let b = 1u8;
/// let c = UInt8::new_witness(cs.clone(), || Ok(16 << 1))?;
///
/// a <<= b;
/// a.enforce_equal(&c)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn shl_assign(&mut self, other: T2) {
let result = self._shl_u128(other.into()).unwrap();
*self = result;
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
alloc::{AllocVar, AllocationMode},
prelude::EqGadget,
uint::test_utils::{run_binary_exhaustive_with_native, run_binary_random_with_native},
R1CSVar,
};
use ark_ff::PrimeField;
use ark_test_curves::bls12_381::Fr;
fn uint_shl<T: PrimUInt, const N: usize, F: PrimeField>(
a: UInt<N, T, F>,
b: T,
) -> Result<(), SynthesisError> {
let cs = a.cs();
let b = b.into() % (N as u128);
let computed = &a << b;
let expected_mode = if a.is_constant() {
AllocationMode::Constant
} else {
AllocationMode::Witness
};
let expected =
UInt::<N, T, F>::new_variable(cs.clone(), || Ok(a.value()? << b), expected_mode)?;
assert_eq!(expected.value(), computed.value());
expected.enforce_equal(&computed)?;
if !a.is_constant() {
assert!(cs.is_satisfied().unwrap());
}
Ok(())
}
#[test]
fn u8_shl() {
run_binary_exhaustive_with_native(uint_shl::<u8, 8, Fr>).unwrap()
}
#[test]
fn u16_shl() {
run_binary_random_with_native::<1000, 16, _, _>(uint_shl::<u16, 16, Fr>).unwrap()
}
#[test]
fn u32_shl() {
run_binary_random_with_native::<1000, 32, _, _>(uint_shl::<u32, 32, Fr>).unwrap()
}
#[test]
fn u64_shl() {
run_binary_random_with_native::<1000, 64, _, _>(uint_shl::<u64, 64, Fr>).unwrap()
}
#[test]
fn u128_shl() {
run_binary_random_with_native::<1000, 128, _, _>(uint_shl::<u128, 128, Fr>).unwrap()
}
}

154
src/uint/shr.rs Normal file
View File

@@ -0,0 +1,154 @@
use ark_ff::PrimeField;
use ark_relations::r1cs::SynthesisError;
use ark_std::{ops::Shr, ops::ShrAssign};
use crate::boolean::Boolean;
use super::{PrimUInt, UInt};
impl<const N: usize, T: PrimUInt, F: PrimeField> UInt<N, T, F> {
fn _shr_u128(&self, other: u128) -> Result<Self, SynthesisError> {
if other < N as u128 {
let mut bits = [Boolean::FALSE; N];
for (a, b) in bits.iter_mut().zip(&self.bits[other as usize..]) {
*a = b.clone();
}
let value = self.value.and_then(|a| Some(a >> other));
Ok(Self { bits, value })
} else {
panic!("attempt to shift right with overflow")
}
}
}
impl<const N: usize, T: PrimUInt, F: PrimeField, T2: PrimUInt> Shr<T2> for UInt<N, T, F> {
type Output = Self;
/// Output `self >> other`.
///
/// If at least one of `self` and `other` are constants, then this method
/// *does not* create any constraints or variables.
///
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let a = UInt8::new_witness(cs.clone(), || Ok(16))?;
/// let b = 1u8;
/// let c = UInt8::new_witness(cs.clone(), || Ok(16 >> 1))?;
///
/// (a >> b).enforce_equal(&c)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn shr(self, other: T2) -> Self::Output {
self._shr_u128(other.into()).unwrap()
}
}
impl<'a, const N: usize, T: PrimUInt, F: PrimeField, T2: PrimUInt> Shr<T2> for &'a UInt<N, T, F> {
type Output = UInt<N, T, F>;
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn shr(self, other: T2) -> Self::Output {
self._shr_u128(other.into()).unwrap()
}
}
impl<const N: usize, T: PrimUInt, F: PrimeField, T2: PrimUInt> ShrAssign<T2> for UInt<N, T, F> {
/// Sets `self = self >> other`.
///
/// If at least one of `self` and `other` are constants, then this method
/// *does not* create any constraints or variables.
///
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let mut a = UInt8::new_witness(cs.clone(), || Ok(16))?;
/// let b = 1u8;
/// let c = UInt8::new_witness(cs.clone(), || Ok(16 >> 1))?;
///
/// a >>= b;
/// a.enforce_equal(&c)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn shr_assign(&mut self, other: T2) {
let result = self._shr_u128(other.into()).unwrap();
*self = result;
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
alloc::{AllocVar, AllocationMode},
prelude::EqGadget,
uint::test_utils::{run_binary_exhaustive_with_native, run_binary_random_with_native},
R1CSVar,
};
use ark_ff::PrimeField;
use ark_test_curves::bls12_381::Fr;
fn uint_shr<T: PrimUInt, const N: usize, F: PrimeField>(
a: UInt<N, T, F>,
b: T,
) -> Result<(), SynthesisError> {
let cs = a.cs();
let b = b.into() % (N as u128);
let computed = &a >> b;
let expected_mode = if a.is_constant() {
AllocationMode::Constant
} else {
AllocationMode::Witness
};
let expected =
UInt::<N, T, F>::new_variable(cs.clone(), || Ok(a.value()? >> b), expected_mode)?;
assert_eq!(expected.value(), computed.value());
expected.enforce_equal(&computed)?;
if !a.is_constant() {
assert!(cs.is_satisfied().unwrap());
}
Ok(())
}
#[test]
fn u8_shr() {
run_binary_exhaustive_with_native(uint_shr::<u8, 8, Fr>).unwrap()
}
#[test]
fn u16_shr() {
run_binary_random_with_native::<1000, 16, _, _>(uint_shr::<u16, 16, Fr>).unwrap()
}
#[test]
fn u32_shr() {
run_binary_random_with_native::<1000, 32, _, _>(uint_shr::<u32, 32, Fr>).unwrap()
}
#[test]
fn u64_shr() {
run_binary_random_with_native::<1000, 64, _, _>(uint_shr::<u64, 64, Fr>).unwrap()
}
#[test]
fn u128_shr() {
run_binary_random_with_native::<1000, 128, _, _>(uint_shr::<u128, 128, Fr>).unwrap()
}
}

144
src/uint/test_utils.rs Normal file
View File

@@ -0,0 +1,144 @@
use ark_relations::r1cs::{ConstraintSystem, SynthesisError};
use std::ops::RangeInclusive;
use crate::test_utils::{self, modes};
use super::*;
pub(crate) fn test_unary_op<T: PrimUInt, const N: usize, F: PrimeField>(
a: T,
mode: AllocationMode,
test: impl FnOnce(UInt<N, T, F>) -> Result<(), SynthesisError>,
) -> Result<(), SynthesisError> {
let cs = ConstraintSystem::<F>::new_ref();
let a = UInt::<N, T, F>::new_variable(cs.clone(), || Ok(a), mode)?;
test(a)
}
pub(crate) fn test_binary_op<T: PrimUInt, const N: usize, F: PrimeField>(
a: T,
b: T,
mode_a: AllocationMode,
mode_b: AllocationMode,
test: impl FnOnce(UInt<N, T, F>, UInt<N, T, F>) -> Result<(), SynthesisError>,
) -> Result<(), SynthesisError> {
let cs = ConstraintSystem::<F>::new_ref();
let a = UInt::<N, T, F>::new_variable(cs.clone(), || Ok(a), mode_a)?;
let b = UInt::<N, T, F>::new_variable(cs.clone(), || Ok(b), mode_b)?;
test(a, b)
}
pub(crate) fn test_binary_op_with_native<T: PrimUInt, const N: usize, F: PrimeField>(
a: T,
b: T,
mode_a: AllocationMode,
test: impl FnOnce(UInt<N, T, F>, T) -> Result<(), SynthesisError>,
) -> Result<(), SynthesisError> {
let cs = ConstraintSystem::<F>::new_ref();
let a = UInt::<N, T, F>::new_variable(cs.clone(), || Ok(a), mode_a)?;
test(a, b)
}
pub(crate) fn run_binary_random<const ITERATIONS: usize, const N: usize, T, F>(
test: impl Fn(UInt<N, T, F>, UInt<N, T, F>) -> Result<(), SynthesisError> + Copy,
) -> Result<(), SynthesisError>
where
T: PrimUInt,
F: PrimeField,
{
let mut rng = ark_std::test_rng();
for _ in 0..ITERATIONS {
for mode_a in modes() {
let a = T::rand(&mut rng);
for mode_b in modes() {
let b = T::rand(&mut rng);
test_binary_op(a, b, mode_a, mode_b, test)?;
}
}
}
Ok(())
}
pub(crate) fn run_binary_exhaustive<const N: usize, T, F>(
test: impl Fn(UInt<N, T, F>, UInt<N, T, F>) -> Result<(), SynthesisError> + Copy,
) -> Result<(), SynthesisError>
where
T: PrimUInt,
F: PrimeField,
RangeInclusive<T>: Iterator<Item = T>,
{
for (mode_a, a) in test_utils::combination(T::min_value()..=T::max_value()) {
for (mode_b, b) in test_utils::combination(T::min_value()..=T::max_value()) {
test_binary_op(a, b, mode_a, mode_b, test)?;
}
}
Ok(())
}
pub(crate) fn run_binary_random_with_native<const ITERATIONS: usize, const N: usize, T, F>(
test: impl Fn(UInt<N, T, F>, T) -> Result<(), SynthesisError> + Copy,
) -> Result<(), SynthesisError>
where
T: PrimUInt,
F: PrimeField,
{
let mut rng = ark_std::test_rng();
for _ in 0..ITERATIONS {
for mode_a in modes() {
let a = T::rand(&mut rng);
let b = T::rand(&mut rng);
test_binary_op_with_native(a, b, mode_a, test)?;
}
}
Ok(())
}
pub(crate) fn run_binary_exhaustive_with_native<const N: usize, T, F>(
test: impl Fn(UInt<N, T, F>, T) -> Result<(), SynthesisError> + Copy,
) -> Result<(), SynthesisError>
where
T: PrimUInt,
F: PrimeField,
RangeInclusive<T>: Iterator<Item = T>,
{
for (mode_a, a) in test_utils::combination(T::min_value()..=T::max_value()) {
for b in T::min_value()..=T::max_value() {
test_binary_op_with_native(a, b, mode_a, test)?;
}
}
Ok(())
}
pub(crate) fn run_unary_random<const ITERATIONS: usize, const N: usize, T, F>(
test: impl Fn(UInt<N, T, F>) -> Result<(), SynthesisError> + Copy,
) -> Result<(), SynthesisError>
where
T: PrimUInt,
F: PrimeField,
{
let mut rng = ark_std::test_rng();
for _ in 0..ITERATIONS {
for mode_a in modes() {
let a = T::rand(&mut rng);
test_unary_op(a, mode_a, test)?;
}
}
Ok(())
}
pub(crate) fn run_unary_exhaustive<const N: usize, T, F>(
test: impl Fn(UInt<N, T, F>) -> Result<(), SynthesisError> + Copy,
) -> Result<(), SynthesisError>
where
T: PrimUInt,
F: PrimeField,
RangeInclusive<T>: Iterator<Item = T>,
{
for (mode, a) in test_utils::combination(T::min_value()..=T::max_value()) {
test_unary_op(a, mode, test)?;
}
Ok(())
}

175
src/uint/xor.rs Normal file
View File

@@ -0,0 +1,175 @@
use ark_ff::Field;
use ark_relations::r1cs::SynthesisError;
use ark_std::{ops::BitXor, ops::BitXorAssign};
use super::*;
impl<const N: usize, T: PrimUInt, F: Field> UInt<N, T, F> {
fn _xor(&self, other: &Self) -> Result<Self, SynthesisError> {
let mut result = self.clone();
for (a, b) in result.bits.iter_mut().zip(&other.bits) {
*a ^= b;
}
result.value = self.value.and_then(|a| Some(a ^ other.value?));
Ok(result)
}
}
impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor<Self> for &'a UInt<N, T, F> {
type Output = UInt<N, T, F>;
/// Outputs `self ^ other`.
///
/// If at least one of `self` and `other` are constants, then this method
/// *does not* create any constraints or variables.
///
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let a = UInt8::new_witness(cs.clone(), || Ok(16))?;
/// let b = UInt8::new_witness(cs.clone(), || Ok(17))?;
/// let c = UInt8::new_witness(cs.clone(), || Ok(1))?;
///
/// (a ^ &b).enforce_equal(&c)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitxor(self, other: Self) -> Self::Output {
self._xor(other).unwrap()
}
}
impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor<&'a Self> for UInt<N, T, F> {
type Output = UInt<N, T, F>;
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitxor(self, other: &Self) -> Self::Output {
self._xor(&other).unwrap()
}
}
impl<'a, const N: usize, T: PrimUInt, F: Field> BitXor<UInt<N, T, F>> for &'a UInt<N, T, F> {
type Output = UInt<N, T, F>;
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitxor(self, other: UInt<N, T, F>) -> Self::Output {
self._xor(&other).unwrap()
}
}
impl<const N: usize, T: PrimUInt, F: Field> BitXor<Self> for UInt<N, T, F> {
type Output = Self;
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitxor(self, other: Self) -> Self::Output {
self._xor(&other).unwrap()
}
}
impl<const N: usize, T: PrimUInt, F: Field> BitXorAssign<Self> for UInt<N, T, F> {
/// Sets `self = self ^ other`.
///
/// If at least one of `self` and `other` are constants, then this method
/// *does not* create any constraints or variables.
///
/// ```
/// # fn main() -> Result<(), ark_relations::r1cs::SynthesisError> {
/// // We'll use the BLS12-381 scalar field for our constraints.
/// use ark_test_curves::bls12_381::Fr;
/// use ark_relations::r1cs::*;
/// use ark_r1cs_std::prelude::*;
///
/// let cs = ConstraintSystem::<Fr>::new_ref();
/// let mut a = UInt8::new_witness(cs.clone(), || Ok(16))?;
/// let b = UInt8::new_witness(cs.clone(), || Ok(17))?;
/// let c = UInt8::new_witness(cs.clone(), || Ok(1))?;
///
/// a ^= b;
/// a.enforce_equal(&c)?;
/// assert!(cs.is_satisfied().unwrap());
/// # Ok(())
/// # }
/// ```
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitxor_assign(&mut self, other: Self) {
let result = self._xor(&other).unwrap();
*self = result;
}
}
impl<'a, const N: usize, T: PrimUInt, F: Field> BitXorAssign<&'a Self> for UInt<N, T, F> {
#[tracing::instrument(target = "r1cs", skip(self, other))]
fn bitxor_assign(&mut self, other: &'a Self) {
let result = self._xor(other).unwrap();
*self = result;
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
alloc::{AllocVar, AllocationMode},
prelude::EqGadget,
uint::test_utils::{run_binary_exhaustive, run_binary_random},
R1CSVar,
};
use ark_ff::PrimeField;
use ark_test_curves::bls12_381::Fr;
fn uint_xor<T: PrimUInt, const N: usize, F: PrimeField>(
a: UInt<N, T, F>,
b: UInt<N, T, F>,
) -> Result<(), SynthesisError> {
let cs = a.cs().or(b.cs());
let both_constant = a.is_constant() && b.is_constant();
let computed = &a ^ &b;
let expected_mode = if both_constant {
AllocationMode::Constant
} else {
AllocationMode::Witness
};
let expected = UInt::<N, T, F>::new_variable(
cs.clone(),
|| Ok(a.value()? ^ b.value()?),
expected_mode,
)?;
assert_eq!(expected.value(), computed.value());
expected.enforce_equal(&computed)?;
if !both_constant {
assert!(cs.is_satisfied().unwrap());
}
Ok(())
}
#[test]
fn u8_xor() {
run_binary_exhaustive(uint_xor::<u8, 8, Fr>).unwrap()
}
#[test]
fn u16_xor() {
run_binary_random::<1000, 16, _, _>(uint_xor::<u16, 16, Fr>).unwrap()
}
#[test]
fn u32_xor() {
run_binary_random::<1000, 32, _, _>(uint_xor::<u32, 32, Fr>).unwrap()
}
#[test]
fn u64_xor() {
run_binary_random::<1000, 64, _, _>(uint_xor::<u64, 64, Fr>).unwrap()
}
#[test]
fn u128_xor() {
run_binary_random::<1000, 128, _, _>(uint_xor::<u128, 128, Fr>).unwrap()
}
}