Avoid deeply nested lc in EvaluationsVar::interpolate_and_evaluate

This commit is contained in:
winderica
2024-04-08 04:47:26 +08:00
committed by arnaucube
parent 8e71ee527e
commit 5b73436084

View File

@@ -158,11 +158,13 @@ impl<F: PrimeField> EvaluationsVar<F> {
.as_ref() .as_ref()
.expect("lagrange interpolator has not been initialized. "); .expect("lagrange interpolator has not been initialized. ");
let lagrange_coeffs = self.compute_lagrange_coefficients(interpolation_point)?; let lagrange_coeffs = self.compute_lagrange_coefficients(interpolation_point)?;
let mut interpolation: FpVar<F> = FpVar::zero();
for i in 0..lagrange_interpolator.domain_order { let interpolation = lagrange_coeffs
let intermediate = &lagrange_coeffs[i] * &self.evals[i]; .iter()
interpolation += &intermediate .zip(&self.evals)
} .take(lagrange_interpolator.domain_order)
.map(|(coeff, eval)| coeff * eval)
.sum::<FpVar<F>>();
Ok(interpolation) Ok(interpolation)
} }
@@ -208,11 +210,11 @@ impl<F: PrimeField> EvaluationsVar<F> {
let alpha_coset_offset_inv = let alpha_coset_offset_inv =
interpolation_point.mul_by_inverse_unchecked(&self.domain.offset())?; interpolation_point.mul_by_inverse_unchecked(&self.domain.offset())?;
// `res` stores the sum of all lagrange polynomials evaluated at alpha
let mut res = FpVar::<F>::zero();
let domain_size = self.domain.size() as usize; let domain_size = self.domain.size() as usize;
for i in 0..domain_size {
// `evals` stores all lagrange polynomials evaluated at alpha
let evals = (0..domain_size)
.map(|i| {
// a'^{-1} where a is the base coset element // a'^{-1} where a is the base coset element
let subgroup_point_inv = subgroup_points[(domain_size - i) % domain_size]; let subgroup_point_inv = subgroup_points[(domain_size - i) % domain_size];
debug_assert_eq!(subgroup_points[i] * subgroup_point_inv, F::one()); debug_assert_eq!(subgroup_points[i] * subgroup_point_inv, F::one());
@@ -230,9 +232,11 @@ impl<F: PrimeField> EvaluationsVar<F> {
// in the coset. // in the coset.
let lag_coeff = lhs.mul_by_inverse_unchecked(&lag_denom)?; let lag_coeff = lhs.mul_by_inverse_unchecked(&lag_denom)?;
let lag_interpoland = &self.evals[i] * lag_coeff; Ok(&self.evals[i] * lag_coeff)
res += lag_interpoland })
} .collect::<Result<Vec<_>, _>>()?;
let res = evals.iter().sum();
Ok(res) Ok(res)
} }
@@ -378,19 +382,16 @@ mod tests {
#[test] #[test]
fn test_interpolate_constant_offset() { fn test_interpolate_constant_offset() {
for n in [11, 12, 13, 14] {
let mut rng = test_rng(); let mut rng = test_rng();
let poly = DensePolynomial::rand(15, &mut rng);
let gen = Fr::get_root_of_unity(1 << 4).unwrap(); let poly = DensePolynomial::rand((1 << n) - 1, &mut rng);
assert_eq!(gen.pow(&[1 << 4]), Fr::one()); let gen = Fr::get_root_of_unity(1 << n).unwrap();
let domain = Radix2DomainVar::new( assert_eq!(gen.pow(&[1 << n]), Fr::one());
gen, let domain = Radix2DomainVar::new(gen, n, FpVar::constant(Fr::rand(&mut rng))).unwrap();
4, // 2^4 = 16
FpVar::constant(Fr::rand(&mut rng)),
)
.unwrap();
let mut coset_point = domain.offset().value().unwrap(); let mut coset_point = domain.offset().value().unwrap();
let mut oracle_evals = Vec::new(); let mut oracle_evals = Vec::new();
for _ in 0..(1 << 4) { for _ in 0..(1 << n) {
oracle_evals.push(poly.evaluate(&coset_point)); oracle_evals.push(poly.evaluate(&coset_point));
coset_point *= gen; coset_point *= gen;
} }
@@ -415,25 +416,27 @@ mod tests {
assert_eq!(actual, expected); assert_eq!(actual, expected);
assert!(cs.is_satisfied().unwrap()); assert!(cs.is_satisfied().unwrap());
println!("number of constraints: {}", cs.num_constraints()) println!("number of constraints: {}", cs.num_constraints());
}
} }
#[test] #[test]
fn test_interpolate_non_constant_offset() { fn test_interpolate_non_constant_offset() {
for n in [11, 12, 13, 14] {
let mut rng = test_rng(); let mut rng = test_rng();
let poly = DensePolynomial::rand(15, &mut rng); let poly = DensePolynomial::rand((1 << n) - 1, &mut rng);
let gen = Fr::get_root_of_unity(1 << 4).unwrap(); let gen = Fr::get_root_of_unity(1 << n).unwrap();
assert_eq!(gen.pow(&[1 << 4]), Fr::one()); assert_eq!(gen.pow(&[1 << n]), Fr::one());
let cs = ConstraintSystem::new_ref(); let cs = ConstraintSystem::new_ref();
let domain = Radix2DomainVar::new( let domain = Radix2DomainVar::new(
gen, gen,
4, // 2^4 = 16 n,
FpVar::new_witness(ns!(cs, "offset"), || Ok(Fr::rand(&mut rng))).unwrap(), FpVar::new_witness(ns!(cs, "offset"), || Ok(Fr::rand(&mut rng))).unwrap(),
) )
.unwrap(); .unwrap();
let mut coset_point = domain.offset().value().unwrap(); let mut coset_point = domain.offset().value().unwrap();
let mut oracle_evals = Vec::new(); let mut oracle_evals = Vec::new();
for _ in 0..(1 << 4) { for _ in 0..(1 << n) {
oracle_evals.push(poly.evaluate(&coset_point)); oracle_evals.push(poly.evaluate(&coset_point));
coset_point *= gen; coset_point *= gen;
} }
@@ -458,7 +461,8 @@ mod tests {
assert_eq!(actual, expected); assert_eq!(actual, expected);
assert!(cs.is_satisfied().unwrap()); assert!(cs.is_satisfied().unwrap());
println!("number of constraints: {}", cs.num_constraints()) println!("number of constraints: {}", cs.num_constraints());
}
} }
#[test] #[test]