diff --git a/CHANGELOG.md b/CHANGELOG.md index 1eaa427..2f34e61 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ ### Features +- [\#71](https://github.com/arkworks-rs/r1cs-std/pull/71) Implement the `Sum` trait for `FpVar`. + ### Improvements ### Bug Fixes diff --git a/src/fields/fp/mod.rs b/src/fields/fp/mod.rs index e3535df..c85bcbd 100644 --- a/src/fields/fp/mod.rs +++ b/src/fields/fp/mod.rs @@ -10,6 +10,7 @@ use crate::{ prelude::*, Assignment, ToConstraintFieldGadget, Vec, }; +use ark_std::iter::Sum; mod cmp; @@ -124,6 +125,36 @@ impl AllocatedFp { AllocatedFp::new(value, variable, self.cs.clone()) } + /// Add many allocated Fp elements together. + /// + /// This does not create any constraints and only creates one linear combination. + pub fn addmany<'a, I: Iterator>(iter: I) -> Self { + let mut cs = ConstraintSystemRef::None; + let mut has_value = true; + let mut value = F::zero(); + let mut new_lc = lc!(); + + for variable in iter { + if !variable.cs.is_none() { + cs = cs.or(variable.cs.clone()); + } + if variable.value.is_none() { + has_value = false; + } else { + value += variable.value.unwrap(); + } + new_lc = new_lc + variable.variable; + } + + let variable = cs.new_lc(new_lc).unwrap(); + + if has_value { + AllocatedFp::new(Some(value), variable, cs.clone()) + } else { + AllocatedFp::new(None, variable, cs.clone()) + } + } + /// Outputs `self - other`. /// /// This does not create any constraints. @@ -1002,3 +1033,61 @@ impl AllocVar for FpVar { } } } + +impl<'a, F: PrimeField> Sum<&'a FpVar> for FpVar { + fn sum>>(iter: I) -> FpVar { + let mut sum_constants = F::zero(); + let sum_variables = FpVar::Var(AllocatedFp::::addmany(iter.filter_map(|x| match x { + FpVar::Constant(c) => { + sum_constants += c; + None + } + FpVar::Var(v) => Some(v), + }))); + + let sum = sum_variables + sum_constants; + sum + } +} + +#[cfg(test)] +mod test { + use crate::alloc::{AllocVar, AllocationMode}; + use crate::eq::EqGadget; + use crate::fields::fp::FpVar; + use crate::R1CSVar; + use ark_relations::r1cs::ConstraintSystem; + use ark_std::{UniformRand, Zero}; + use ark_test_curves::bls12_381::Fr; + + #[test] + fn test_sum_fpvar() { + let mut rng = ark_std::test_rng(); + let cs = ConstraintSystem::new_ref(); + + let mut sum_expected = Fr::zero(); + + let mut v = Vec::new(); + for _ in 0..10 { + let a = Fr::rand(&mut rng); + sum_expected += &a; + v.push( + FpVar::::new_variable(cs.clone(), || Ok(a), AllocationMode::Constant).unwrap(), + ); + } + for _ in 0..10 { + let a = Fr::rand(&mut rng); + sum_expected += &a; + v.push( + FpVar::::new_variable(cs.clone(), || Ok(a), AllocationMode::Witness).unwrap(), + ); + } + + let sum: FpVar = v.iter().sum(); + + sum.enforce_equal(&FpVar::Constant(sum_expected)).unwrap(); + + assert!(cs.is_satisfied().unwrap()); + assert_eq!(sum.value().unwrap(), sum_expected); + } +}