From 768db4eb04879e62be06d2c51a294c6281c8e950 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Benedikt=20B=C3=BCnz?= Date: Sun, 20 Nov 2022 09:26:35 -0800 Subject: [PATCH] Optimize verifier eq (#102) * wip we need to be able to do batch opening for different poly sizes or pad poly with zeros * fix small public inputs. Only works for pow2 pubinput Co-authored-by: Charles Chen --- arithmetic/src/lib.rs | 4 +++- arithmetic/src/virtual_polynomial.rs | 17 +++++++++++++++ hyperplonk/Cargo.toml | 1 - hyperplonk/src/mock.rs | 16 +++++++++----- hyperplonk/src/snark.rs | 21 ++++++++++++------- hyperplonk/src/utils.rs | 7 +++++++ .../src/poly_iop/sum_check/verifier.rs | 2 ++ subroutines/src/poly_iop/zero_check/mod.rs | 10 +++------ 8 files changed, 56 insertions(+), 22 deletions(-) diff --git a/arithmetic/src/lib.rs b/arithmetic/src/lib.rs index 2fc2b5f..dfedeaf 100644 --- a/arithmetic/src/lib.rs +++ b/arithmetic/src/lib.rs @@ -12,4 +12,6 @@ pub use multilinear_polynomial::{ }; pub use univariate_polynomial::{build_l, get_uni_domain}; pub use util::{bit_decompose, gen_eval_point, get_batched_nv, get_index}; -pub use virtual_polynomial::{build_eq_x_r, build_eq_x_r_vec, VPAuxInfo, VirtualPolynomial}; +pub use virtual_polynomial::{ + build_eq_x_r, build_eq_x_r_vec, eq_eval, VPAuxInfo, VirtualPolynomial, +}; diff --git a/arithmetic/src/virtual_polynomial.rs b/arithmetic/src/virtual_polynomial.rs index a6f1eb3..8094337 100644 --- a/arithmetic/src/virtual_polynomial.rs +++ b/arithmetic/src/virtual_polynomial.rs @@ -325,6 +325,23 @@ impl VirtualPolynomial { } } +/// Evaluate eq polynomial. +pub fn eq_eval(x: &[F], y: &[F]) -> Result { + if x.len() != y.len() { + return Err(ArithErrors::InvalidParameters( + "x and y have different length".to_string(), + )); + } + let start = start_timer!(|| "eq_eval"); + let mut res = F::one(); + for (&xi, &yi) in x.iter().zip(y.iter()) { + let xi_yi = xi * yi; + res *= xi_yi + xi_yi - xi - yi + F::one(); + } + end_timer!(start); + Ok(res) +} + /// This function build the eq(x, r) polynomial for any given r. /// /// Evaluate diff --git a/hyperplonk/Cargo.toml b/hyperplonk/Cargo.toml index 5045d17..d1dd539 100644 --- a/hyperplonk/Cargo.toml +++ b/hyperplonk/Cargo.toml @@ -24,7 +24,6 @@ rayon = { version = "1.5.2", default-features = false, optional = true } [dev-dependencies] ark-bls12-381 = { version = "0.3.0", default-features = false, features = [ "curve" ] } - # Benchmarks [[bench]] name = "hyperplonk-benches" diff --git a/hyperplonk/src/mock.rs b/hyperplonk/src/mock.rs index 6af98a2..7fa4372 100644 --- a/hyperplonk/src/mock.rs +++ b/hyperplonk/src/mock.rs @@ -10,6 +10,7 @@ use crate::{ }; pub struct MockCircuit { + pub public_inputs: Vec, pub witnesses: Vec>, pub index: HyperPlonkIndex, } @@ -85,10 +86,12 @@ impl MockCircuit { witnesses[i].append(cur_witness[i]); } } + let pub_input_len = ark_std::cmp::min(4, num_constraints); + let public_inputs = witnesses[0].0[0..pub_input_len].to_vec(); let params = HyperPlonkParams { num_constraints, - num_pub_input: num_constraints, + num_pub_input: public_inputs.len(), gate_func: gate.clone(), }; @@ -99,7 +102,11 @@ impl MockCircuit { selectors, }; - Self { witnesses, index } + Self { + public_inputs, + witnesses, + index, + } } pub fn is_satisfied(&self) -> bool { @@ -177,7 +184,6 @@ mod test { assert!(circuit.is_satisfied()); let index = circuit.index; - // generate pk and vks let (pk, vk) = as HyperPlonkSNARK>>::preprocess( @@ -187,14 +193,14 @@ mod test { let proof = as HyperPlonkSNARK>>::prove( &pk, - &circuit.witnesses[0].0, + &circuit.public_inputs, &circuit.witnesses, )?; let verify = as HyperPlonkSNARK>>::verify( &vk, - &circuit.witnesses[0].0, + &circuit.public_inputs, &proof, )?; assert!(verify); diff --git a/hyperplonk/src/snark.rs b/hyperplonk/src/snark.rs index 4318d99..e1615ac 100644 --- a/hyperplonk/src/snark.rs +++ b/hyperplonk/src/snark.rs @@ -324,8 +324,11 @@ where // - 4.4. public input consistency checks // - pi_poly(r_pi) where r_pi is sampled from transcript let r_pi = transcript.get_and_append_challenge_vectors(b"r_pi", ell)?; - let tmp_point = [vec![E::Fr::zero(); num_vars - ell], r_pi].concat(); - pcs_acc.insert_poly_and_points(&witness_polys[0], &witness_commits[0], &tmp_point); + // padded with zeros + let r_pi_padded = [r_pi, vec![E::Fr::zero(); num_vars - ell]].concat(); + // Evaluate witness_poly[0] at r_pi||0s which is equal to public_input evaluated + // at r_pi. Assumes that public_input is a power of 2 + pcs_acc.insert_poly_and_points(&witness_polys[0], &witness_commits[0], &r_pi_padded); end_timer!(step); // ======================================================================= @@ -515,7 +518,7 @@ where // ======================================================================= // 3. Verify the opening against the commitment // ======================================================================= - let step = start_timer!(|| "verify commitments"); + let step = start_timer!(|| "assemble commitments"); // generate evaluation points and commitments let mut comms = vec![]; @@ -535,7 +538,6 @@ where points.push(perm_check_point_0.clone()); points.push(perm_check_point_1.clone()); points.push(prod_final_query_point); - // frac(x)'s points comms.push(proof.perm_check_proof.frac_comm); comms.push(proof.perm_check_proof.frac_comm); @@ -575,21 +577,24 @@ where // - 4.4. public input consistency checks // - pi_poly(r_pi) where r_pi is sampled from transcript let r_pi = transcript.get_and_append_challenge_vectors(b"r_pi", ell)?; - let tmp_point = [vec![E::Fr::zero(); num_vars - ell], r_pi].concat(); + // check public evaluation let pi_poly = DenseMultilinearExtension::from_evaluations_slice(ell as usize, pub_input); - let expect_pi_eval = evaluate_opt(&pi_poly, &tmp_point[..]); + let expect_pi_eval = evaluate_opt(&pi_poly, &r_pi[..]); if expect_pi_eval != *pi_eval { return Err(HyperPlonkErrors::InvalidProver(format!( "Public input eval mismatch: got {}, expect {}", pi_eval, expect_pi_eval, ))); } - comms.push(proof.witness_commits[0]); - points.push(tmp_point); + let r_pi_padded = [r_pi, vec![E::Fr::zero(); num_vars - ell]].concat(); + comms.push(proof.witness_commits[0]); + points.push(r_pi_padded); assert_eq!(comms.len(), proof.batch_openings.f_i_eval_at_point_i.len()); + end_timer!(step); + let step = start_timer!(|| "PCS batch verify"); // check proof let res = PCS::batch_verify( &vk.pcs_param, diff --git a/hyperplonk/src/utils.rs b/hyperplonk/src/utils.rs index da11f13..69a6d5a 100644 --- a/hyperplonk/src/utils.rs +++ b/hyperplonk/src/utils.rs @@ -137,6 +137,13 @@ pub(crate) fn prover_sanity_check( params.num_pub_input ))); } + if !pub_input.len().is_power_of_two() { + return Err(HyperPlonkErrors::InvalidProver(format!( + "Public input length is not power of two: got {}", + pub_input.len(), + ))); + } + // witnesses length for (i, w) in witnesses.iter().enumerate() { if w.0.len() != params.num_constraints { diff --git a/subroutines/src/poly_iop/sum_check/verifier.rs b/subroutines/src/poly_iop/sum_check/verifier.rs index 90edb48..99377da 100644 --- a/subroutines/src/poly_iop/sum_check/verifier.rs +++ b/subroutines/src/poly_iop/sum_check/verifier.rs @@ -178,6 +178,8 @@ impl SumCheckVerifier for IOPVerifierState { /// This implementation is linear in number of inputs in terms of field /// operations. It also has a quadratic term in primitive operations which is /// negligible compared to field operations. +/// TODO: The quadratic term can be removed by precomputing the lagrange +/// coefficients. fn interpolate_uni_poly(p_i: &[F], eval_at: F) -> Result { let start = start_timer!(|| "sum check interpolate uni poly opt"); diff --git a/subroutines/src/poly_iop/zero_check/mod.rs b/subroutines/src/poly_iop/zero_check/mod.rs index 12d3424..144e8af 100644 --- a/subroutines/src/poly_iop/zero_check/mod.rs +++ b/subroutines/src/poly_iop/zero_check/mod.rs @@ -3,9 +3,8 @@ use std::fmt::Debug; use crate::poly_iop::{errors::PolyIOPErrors, sum_check::SumCheck, PolyIOP}; -use arithmetic::build_eq_x_r; +use arithmetic::eq_eval; use ark_ff::PrimeField; -use ark_poly::MultilinearExtension; use ark_std::{end_timer, start_timer}; use transcript::IOPTranscript; @@ -103,11 +102,8 @@ impl ZeroCheck for PolyIOP { // expected_eval = sumcheck.expect_eval/eq(v, r) // where v = sum_check_sub_claim.point - let eq_x_r = build_eq_x_r(&r)?; - let expected_evaluation = sum_subclaim.expected_evaluation - / eq_x_r.evaluate(&sum_subclaim.point).ok_or_else(|| { - PolyIOPErrors::InvalidParameters("evaluation dimension does not match".to_string()) - })?; + let eq_x_r_eval = eq_eval(&sum_subclaim.point, &r)?; + let expected_evaluation = sum_subclaim.expected_evaluation / eq_x_r_eval; end_timer!(start); Ok(ZeroCheckSubClaim {