diff --git a/subroutines/src/pcs/multilinear_kzg/batching.rs b/subroutines/src/pcs/multilinear_kzg/batching.rs index 17b855b..c235c2b 100644 --- a/subroutines/src/pcs/multilinear_kzg/batching.rs +++ b/subroutines/src/pcs/multilinear_kzg/batching.rs @@ -13,9 +13,7 @@ use crate::{ poly_iop::{prelude::SumCheck, PolyIOP}, IOPProof, }; -use arithmetic::{ - build_eq_x_r_vec, fix_last_variables, DenseMultilinearExtension, VPAuxInfo, VirtualPolynomial, -}; +use arithmetic::{build_eq_x_r_vec, DenseMultilinearExtension, VPAuxInfo, VirtualPolynomial}; use ark_ec::{AffineCurve, PairingEngine, ProjectiveCurve}; use ark_std::{end_timer, log2, start_timer, One, Zero}; use std::{marker::PhantomData, rc::Rc}; @@ -38,10 +36,11 @@ where /// Steps: /// 1. get challenge point t from transcript /// 2. build eq(t,i) for i in [0..k] -/// 3. build \tilde g(i, b) = eq(t, i) * f_i(b) -/// 4. compute \tilde eq -/// 5. run sumcheck on \tilde eq * \tilde g(i, b) -/// 6. build g'(a2) where (a1, a2) is the sumcheck's point +/// 3. build \tilde g_i(b) = eq(t, i) * f_i(b) +/// 4. compute \tilde eq_i(b) = eq(b, point_i) +/// 5. run sumcheck on \sum_i=1..k \tilde eq_i * \tilde g_i +/// 6. build g'(X) = \sum_i=1..k \tilde eq_i(a2) * \tilde g_i(X) where (a2) is +/// the sumcheck's point 7. open g'(X) at point (a2) pub(crate) fn multi_open_internal( prover_param: &PCS::ProverParam, polynomials: &[PCS::Polynomial], @@ -64,7 +63,6 @@ where let num_var = polynomials[0].num_vars; let k = polynomials.len(); let ell = log2(k) as usize; - let merged_num_var = num_var + ell; // challenge point t let t = transcript.get_and_append_challenge_vectors("t".as_ref(), ell)?; @@ -72,40 +70,39 @@ where // eq(t, i) for i in [0..k] let eq_t_i_list = build_eq_x_r_vec(t.as_ref())?; - // \tilde g(i, b) = eq(t, i) * f_i(b) + // \tilde g_i(b) = eq(t, i) * f_i(b) let timer = start_timer!(|| format!("compute tilde g for {} points", points.len())); - let mut tilde_g_eval = vec![E::Fr::zero(); 1 << (ell + num_var)]; - let block_size = 1 << num_var; + let mut tilde_gs = vec![]; for (index, f_i) in polynomials.iter().enumerate() { + let mut tilde_g_eval = vec![E::Fr::zero(); 1 << num_var]; for (j, &f_i_eval) in f_i.iter().enumerate() { - tilde_g_eval[index * block_size + j] = f_i_eval * eq_t_i_list[index]; + tilde_g_eval[j] = f_i_eval * eq_t_i_list[index]; } + tilde_gs.push(Rc::new(DenseMultilinearExtension::from_evaluations_vec( + num_var, + tilde_g_eval, + ))); } - let tilde_g = Rc::new(DenseMultilinearExtension::from_evaluations_vec( - merged_num_var, - tilde_g_eval, - )); end_timer!(timer); let timer = start_timer!(|| format!("compute tilde eq for {} points", points.len())); - let mut tilde_eq_eval = vec![E::Fr::zero(); 1 << (ell + num_var)]; - for (index, point) in points.iter().enumerate() { + let mut tilde_eqs = vec![]; + for point in points.iter() { let eq_b_zi = build_eq_x_r_vec(point)?; - let start = index * block_size; - tilde_eq_eval[start..start + block_size].copy_from_slice(eq_b_zi.as_slice()); + tilde_eqs.push(Rc::new(DenseMultilinearExtension::from_evaluations_vec( + num_var, eq_b_zi, + ))); } - let tilde_eq = Rc::new(DenseMultilinearExtension::from_evaluations_vec( - merged_num_var, - tilde_eq_eval, - )); end_timer!(timer); // built the virtual polynomial for SumCheck - let timer = start_timer!(|| format!("sum check prove of {} variables", num_var + ell)); + let timer = start_timer!(|| format!("sum check prove of {} variables", num_var)); let step = start_timer!(|| "add mle"); - let mut sum_check_vp = VirtualPolynomial::new(num_var + ell); - sum_check_vp.add_mle_list([tilde_g.clone(), tilde_eq], E::Fr::one())?; + let mut sum_check_vp = VirtualPolynomial::new(num_var); + for (tilde_g, tilde_eq) in tilde_gs.iter().zip(tilde_eqs.into_iter()) { + sum_check_vp.add_mle_list([tilde_g.clone(), tilde_eq], E::Fr::one())?; + } end_timer!(step); let proof = match as SumCheck>::prove(&sum_check_vp, transcript) { @@ -120,15 +117,23 @@ where end_timer!(timer); - // (a1, a2) := sumcheck's point - let step = start_timer!(|| "open at a2"); - let a1 = &proof.point[num_var..]; + // a2 := sumcheck's point let a2 = &proof.point[..num_var]; - end_timer!(step); - // build g'(a2) + // build g'(X) = \sum_i=1..k \tilde eq_i(a2) * \tilde g_i(X) where (a2) is the + // sumcheck's point \tilde eq_i(a2) = eq(a2, point_i) let step = start_timer!(|| "evaluate at a2"); - let g_prime = Rc::new(fix_last_variables(&tilde_g, a1)); + let mut g_prime_evals = vec![E::Fr::zero(); 1 << num_var]; + for (tilde_g, point) in tilde_gs.iter().zip(points.iter()) { + let eq_i_a2 = eq_eval(a2, point)?; + for (j, &tilde_g_eval) in tilde_g.iter().enumerate() { + g_prime_evals[j] += tilde_g_eval * eq_i_a2; + } + } + let g_prime = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + num_var, + g_prime_evals, + )); end_timer!(step); let step = start_timer!(|| "pcs open"); @@ -150,8 +155,8 @@ where /// Steps: /// 1. get challenge point t from transcript /// 2. build g' commitment -/// 3. ensure \sum_i eq(t, ) * f_i_evals matches the sum via SumCheck -/// verification 4. verify commitment +/// 3. ensure \sum_i eq(a2, point_i) * eq(t, ) * f_i_evals matches the sum +/// via SumCheck verification 4. verify commitment pub(crate) fn batch_verify_internal( verifier_param: &PCS::VerifierParam, f_i_commitments: &[Commitment], @@ -175,34 +180,33 @@ where let k = f_i_commitments.len(); let ell = log2(k) as usize; - let num_var = proof.sum_check_proof.point.len() - ell; + let num_var = proof.sum_check_proof.point.len(); // challenge point t let t = transcript.get_and_append_challenge_vectors("t".as_ref(), ell)?; - // sum check point (a1, a2) - let a1 = &proof.sum_check_proof.point[num_var..]; + // sum check point (a2) let a2 = &proof.sum_check_proof.point[..num_var]; // build g' commitment - let eq_a1_list = build_eq_x_r_vec(a1)?; let eq_t_list = build_eq_x_r_vec(t.as_ref())?; let mut g_prime_commit = E::G1Affine::zero().into_projective(); - for i in 0..k { - let tmp = eq_a1_list[i] * eq_t_list[i]; + + for (i, point) in points.iter().enumerate() { + let eq_i_a2 = eq_eval(a2, point)?; + let tmp = eq_i_a2 * eq_t_list[i]; g_prime_commit += &f_i_commitments[i].0.mul(tmp); } // ensure \sum_i eq(t, ) * f_i_evals matches the sum via SumCheck - // verification let mut sum = E::Fr::zero(); for (i, &e) in eq_t_list.iter().enumerate().take(k) { sum += e * proof.f_i_eval_at_point_i[i]; } let aux_info = VPAuxInfo { max_degree: 2, - num_variables: num_var + ell, + num_variables: num_var, phantom: PhantomData, }; let subclaim = match as SumCheck>::verify( @@ -219,11 +223,7 @@ where )); }, }; - let mut eq_tilde_eval = E::Fr::zero(); - for (point, &coef) in points.iter().zip(eq_a1_list.iter()) { - eq_tilde_eval += coef * eq_eval(a2, point)?; - } - let tilde_g_eval = subclaim.expected_evaluation / eq_tilde_eval; + let tilde_g_eval = subclaim.expected_evaluation; // verify commitment let res = PCS::verify(