diff --git a/.gitignore b/.gitignore index 7d45714..4e98029 100644 --- a/.gitignore +++ b/.gitignore @@ -296,4 +296,7 @@ TSWLatexianTemp* # REVTeX puts footnotes in the bibliography by default, unless the nofootinbib # option is specified. Footnotes are the stored in a file with suffix Notes.bib. # Uncomment the next line to have this generated file ignored. -#*Notes.bib \ No newline at end of file +#*Notes.bib + + +*.txt \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 9467f58..31bc286 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,4 +4,5 @@ members = [ "hyperplonk", "poly-iop", "transcript", + "util" ] diff --git a/arithmetic/Cargo.toml b/arithmetic/Cargo.toml index 669ad28..6959c66 100644 --- a/arithmetic/Cargo.toml +++ b/arithmetic/Cargo.toml @@ -19,6 +19,7 @@ rayon = { version = "1.5.2", default-features = false, optional = true } [dev-dependencies] ark-ec = { version = "^0.3.0", default-features = false } +criterion = "0.3.0" [features] # default = [ "parallel", "print-trace" ] @@ -31,4 +32,10 @@ parallel = [ ] print-trace = [ "ark-std/print-trace" - ] \ No newline at end of file + ] + + +[[bench]] +name = "mle_eval" +path = "benches/bench.rs" +harness = false \ No newline at end of file diff --git a/arithmetic/benches/bench.rs b/arithmetic/benches/bench.rs new file mode 100644 index 0000000..57ffafd --- /dev/null +++ b/arithmetic/benches/bench.rs @@ -0,0 +1,37 @@ +#[macro_use] +extern crate criterion; + +use arithmetic::fix_variables; +use ark_bls12_381::Fr; +use ark_ff::Field; +use ark_poly::{DenseMultilinearExtension, MultilinearExtension}; +use ark_std::{ops::Range, test_rng}; +use criterion::{black_box, BenchmarkId, Criterion}; + +const NUM_VARIABLES_RANGE: Range = 10..21; + +fn evaluation_op_bench(c: &mut Criterion) { + let mut rng = test_rng(); + let mut group = c.benchmark_group("Evaluate"); + for nv in NUM_VARIABLES_RANGE { + group.bench_with_input(BenchmarkId::new("evaluate native", nv), &nv, |b, &nv| { + let poly = DenseMultilinearExtension::::rand(nv, &mut rng); + let point: Vec<_> = (0..nv).map(|_| F::rand(&mut rng)).collect(); + b.iter(|| black_box(poly.evaluate(&point).unwrap())) + }); + + group.bench_with_input(BenchmarkId::new("evaluate optimized", nv), &nv, |b, &nv| { + let poly = DenseMultilinearExtension::::rand(nv, &mut rng); + let point: Vec<_> = (0..nv).map(|_| F::rand(&mut rng)).collect(); + b.iter(|| black_box(fix_variables(&poly, &point))) + }); + } + group.finish(); +} + +fn bench_bls_381(c: &mut Criterion) { + evaluation_op_bench::(c); +} + +criterion_group!(benches, bench_bls_381); +criterion_main!(benches); diff --git a/arithmetic/src/lib.rs b/arithmetic/src/lib.rs index 82bc047..43576ec 100644 --- a/arithmetic/src/lib.rs +++ b/arithmetic/src/lib.rs @@ -1,7 +1,15 @@ mod errors; mod multilinear_polynomial; +mod univariate_polynomial; +mod util; mod virtual_polynomial; pub use errors::ArithErrors; -pub use multilinear_polynomial::{random_zero_mle_list, DenseMultilinearExtension}; +pub use multilinear_polynomial::{ + evaluate_no_par, evaluate_opt, fix_first_variable, fix_variables, identity_permutation_mle, + merge_polynomials, random_mle_list, random_permutation_mle, random_zero_mle_list, + DenseMultilinearExtension, +}; +pub use univariate_polynomial::{build_l, get_uni_domain}; +pub use util::{bit_decompose, gen_eval_point, get_batched_nv, get_index}; pub use virtual_polynomial::{build_eq_x_r, VPAuxInfo, VirtualPolynomial}; diff --git a/arithmetic/src/multilinear_polynomial.rs b/arithmetic/src/multilinear_polynomial.rs index 077edec..1bcef2b 100644 --- a/arithmetic/src/multilinear_polynomial.rs +++ b/arithmetic/src/multilinear_polynomial.rs @@ -1,9 +1,49 @@ -use ark_ff::PrimeField; +use crate::{util::get_batched_nv, ArithErrors}; +use ark_ff::{Field, PrimeField}; +use ark_poly::MultilinearExtension; use ark_std::{end_timer, rand::RngCore, start_timer}; +#[cfg(feature = "parallel")] +use rayon::prelude::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}; use std::rc::Rc; pub use ark_poly::DenseMultilinearExtension; +/// Sample a random list of multilinear polynomials. +/// Returns +/// - the list of polynomials, +/// - its sum of polynomial evaluations over the boolean hypercube. +pub fn random_mle_list( + nv: usize, + degree: usize, + rng: &mut R, +) -> (Vec>>, F) { + let start = start_timer!(|| "sample random mle list"); + let mut multiplicands = Vec::with_capacity(degree); + for _ in 0..degree { + multiplicands.push(Vec::with_capacity(1 << nv)) + } + let mut sum = F::zero(); + + for _ in 0..(1 << nv) { + let mut product = F::one(); + + for e in multiplicands.iter_mut() { + let val = F::rand(rng); + e.push(val); + product *= val; + } + sum += product; + } + + let list = multiplicands + .into_iter() + .map(|x| Rc::new(DenseMultilinearExtension::from_evaluations_vec(nv, x))) + .collect(); + + end_timer!(start); + (list, sum) +} + // Build a randomize list of mle-s whose sum is zero. pub fn random_zero_mle_list( nv: usize, @@ -31,3 +71,142 @@ pub fn random_zero_mle_list( end_timer!(start); list } + +/// An MLE that represent an identity permutation: `f(index) \mapto index` +pub fn identity_permutation_mle( + num_vars: usize, +) -> Rc> { + let s_id_vec = (0..1u64 << num_vars).map(F::from).collect(); + Rc::new(DenseMultilinearExtension::from_evaluations_vec( + num_vars, s_id_vec, + )) +} + +/// An MLE that represent a random permutation +pub fn random_permutation_mle( + num_vars: usize, + rng: &mut R, +) -> Rc> { + let len = 1u64 << num_vars; + let mut s_id_vec: Vec = (0..len).map(F::from).collect(); + let mut s_perm_vec = vec![]; + for _ in 0..len { + let index = rng.next_u64() as usize % s_id_vec.len(); + s_perm_vec.push(s_id_vec.remove(index)); + } + Rc::new(DenseMultilinearExtension::from_evaluations_vec( + num_vars, s_perm_vec, + )) +} + +pub fn evaluate_opt(poly: &DenseMultilinearExtension, point: &[F]) -> F { + assert_eq!(poly.num_vars, point.len()); + fix_variables(poly, point).evaluations[0] +} + +pub fn fix_variables( + poly: &DenseMultilinearExtension, + partial_point: &[F], +) -> DenseMultilinearExtension { + assert!( + partial_point.len() <= poly.num_vars, + "invalid size of partial point" + ); + let nv = poly.num_vars; + let mut poly = poly.evaluations.to_vec(); + let dim = partial_point.len(); + // evaluate single variable of partial point from left to right + for (i, point) in partial_point.iter().enumerate().take(dim) { + poly = fix_one_variable_helper(&poly, nv - i, point); + } + + DenseMultilinearExtension::::from_evaluations_slice(nv - dim, &poly[..(1 << (nv - dim))]) +} + +pub fn fix_first_variable( + poly: &DenseMultilinearExtension, + partial_point: &F, +) -> DenseMultilinearExtension { + assert!(poly.num_vars != 0, "invalid size of partial point"); + + let nv = poly.num_vars; + let res = fix_one_variable_helper(&poly.evaluations, nv, partial_point); + DenseMultilinearExtension::::from_evaluations_slice(nv - 1, &res) +} + +fn fix_one_variable_helper(data: &[F], nv: usize, point: &F) -> Vec { + let mut res = vec![F::zero(); 1 << (nv - 1)]; + let one_minus_p = F::one() - point; + + // evaluate single variable of partial point from left to right + #[cfg(not(feature = "parallel"))] + for b in 0..(1 << (nv - 1)) { + res[b] = data[b << 1] * one_minus_p + data[(b << 1) + 1] * point; + } + + #[cfg(feature = "parallel")] + if nv >= 13 { + // on my computer we parallelization doesn't help till nv >= 13 + res.par_iter_mut().enumerate().for_each(|(i, x)| { + *x = data[i << 1] * one_minus_p + data[(i << 1) + 1] * point; + }); + } else { + for b in 0..(1 << (nv - 1)) { + res[b] = data[b << 1] * one_minus_p + data[(b << 1) + 1] * point; + } + } + + res +} + +pub fn evaluate_no_par(poly: &DenseMultilinearExtension, point: &[F]) -> F { + assert_eq!(poly.num_vars, point.len()); + fix_variables_no_par(poly, point).evaluations[0] +} + +fn fix_variables_no_par( + poly: &DenseMultilinearExtension, + partial_point: &[F], +) -> DenseMultilinearExtension { + assert!( + partial_point.len() <= poly.num_vars, + "invalid size of partial point" + ); + let nv = poly.num_vars; + let mut poly = poly.evaluations.to_vec(); + let dim = partial_point.len(); + // evaluate single variable of partial point from left to right + for i in 1..dim + 1 { + let r = partial_point[i - 1]; + let one_minus_r = F::one() - r; + for b in 0..(1 << (nv - i)) { + poly[b] = poly[b << 1] * one_minus_r + poly[(b << 1) + 1] * r; + } + } + DenseMultilinearExtension::from_evaluations_slice(nv - dim, &poly[..(1 << (nv - dim))]) +} + +/// merge a set of polynomials. Returns an error if the +/// polynomials do not share a same number of nvs. +pub fn merge_polynomials( + polynomials: &[Rc>], +) -> Result>, ArithErrors> { + let nv = polynomials[0].num_vars(); + for poly in polynomials.iter() { + if nv != poly.num_vars() { + return Err(ArithErrors::InvalidParameters( + "num_vars do not match for polynomials".to_string(), + )); + } + } + + let merged_nv = get_batched_nv(nv, polynomials.len()); + let mut scalars = vec![]; + for poly in polynomials.iter() { + scalars.extend_from_slice(poly.to_evaluations().as_slice()); + } + scalars.extend_from_slice(vec![F::zero(); (1 << merged_nv) - scalars.len()].as_ref()); + Ok(Rc::new(DenseMultilinearExtension::from_evaluations_vec( + merged_nv, scalars, + ))) +} diff --git a/arithmetic/src/univariate_polynomial.rs b/arithmetic/src/univariate_polynomial.rs new file mode 100644 index 0000000..bbc525e --- /dev/null +++ b/arithmetic/src/univariate_polynomial.rs @@ -0,0 +1,340 @@ +// TODO: remove +#![allow(dead_code)] + +use crate::{bit_decompose, ArithErrors}; +use ark_ff::PrimeField; +use ark_poly::{ + univariate::DensePolynomial, EvaluationDomain, Evaluations, Radix2EvaluationDomain, +}; +use ark_std::log2; + +/// Given a list of points, build `l(points)` which is a list of univariate +/// polynomials that goes through the points; extend the dimension of the points +/// by `log(points.len())` if `with_suffix` is set. +pub fn build_l( + points: &[Vec], + domain: &Radix2EvaluationDomain, + with_suffix: bool, +) -> Result>, ArithErrors> { + let mut uni_polys = Vec::new(); + if with_suffix { + // 1.1 build the indexes and the univariate polys that go through the indexes + let prefix_len = log2(points.len()) as usize; + let indexes: Vec> = (0..points.len()) + .map(|x| bit_decompose(x as u64, prefix_len)) + .collect(); + for i in 0..prefix_len { + let eval: Vec = indexes + .iter() + .map(|x| F::from(x[prefix_len - i - 1])) + .collect(); + + uni_polys.push(Evaluations::from_vec_and_domain(eval, *domain).interpolate()); + } + } + // 1.2 build the actual univariate polys that go through the points + uni_polys.extend_from_slice(build_l_internal(points, domain)?.as_slice()); + + Ok(uni_polys) +} + +/// Given a list of points, build `l(points)` which is a list of univariate +/// polynomials that goes through the points. +pub(crate) fn build_l_internal( + points: &[Vec], + domain: &Radix2EvaluationDomain, +) -> Result>, ArithErrors> { + let mut uni_polys = Vec::new(); + let num_var = points[0].len(); + // build the actual univariate polys that go through the points + for i in 0..num_var { + let mut eval: Vec = points.iter().map(|x| x[i]).collect(); + eval.extend_from_slice(vec![F::zero(); domain.size as usize - eval.len()].as_slice()); + uni_polys.push(Evaluations::from_vec_and_domain(eval, *domain).interpolate()) + } + Ok(uni_polys) +} + +/// get the domain for the univariate polynomial +#[inline] +pub fn get_uni_domain( + uni_poly_degree: usize, +) -> Result, ArithErrors> { + let domain = match Radix2EvaluationDomain::::new(uni_poly_degree) { + Some(p) => p, + None => { + return Err(ArithErrors::InvalidParameters( + "failed to build radix 2 domain".to_string(), + )) + }, + }; + Ok(domain) +} + +#[cfg(test)] +mod test { + use super::*; + use ark_bls12_381::Fr; + use ark_ff::{field_new, One}; + use ark_poly::UVPolynomial; + + #[test] + fn test_build_l_with_suffix() -> Result<(), ArithErrors> { + test_build_l_with_suffix_helper::() + } + + fn test_build_l_with_suffix_helper() -> Result<(), ArithErrors> { + // point 1 is [1, 2] + let point1 = vec![Fr::from(1u64), Fr::from(2u64)]; + + // point 2 is [3, 4] + let point2 = vec![Fr::from(3u64), Fr::from(4u64)]; + + // point 3 is [5, 6] + let point3 = vec![Fr::from(5u64), Fr::from(6u64)]; + + { + let domain = get_uni_domain::(2)?; + let l = build_l(&[point1.clone(), point2.clone()], &domain, true)?; + + // roots: [1, -1] + // l0 = -1/2 * x + 1/2 + // l1 = -x + 2 + // l2 = -x + 3 + let l0 = DensePolynomial::from_coefficients_vec(vec![ + Fr::one() / Fr::from(2u64), + -Fr::one() / Fr::from(2u64), + ]); + let l1 = DensePolynomial::from_coefficients_vec(vec![Fr::from(2u64), -Fr::one()]); + let l2 = DensePolynomial::from_coefficients_vec(vec![Fr::from(3u64), -Fr::one()]); + + assert_eq!(l0, l[0], "l0 not equal"); + assert_eq!(l1, l[1], "l1 not equal"); + assert_eq!(l2, l[2], "l2 not equal"); + } + + { + let domain = get_uni_domain::(3)?; + let l = build_l(&[point1, point2, point3], &domain, true)?; + + // sage: q = 52435875175126190479447740508185965837690552500527637822603658699938581184513 + // sage: P. = PolynomialRing(Zmod(q)) + // sage: root1 = 1 + // sage: root2 = 0x8D51CCCE760304D0EC030002760300000001000000000000 + // sage: root3 = -1 + // sage: root4 = -root2 + // Arkwork's code is a bit wired: it also interpolate (root4, 0) + // which returns a degree 3 polynomial, instead of degree 2 + + // ======================== + // l0: [0, 0, 1] + // ======================== + // sage: points = [(root1, 0), (root2, 0), (root3, 1), (root4, 0)] + // sage: P.lagrange_polynomial(points) + // 13108968793781547619861935127046491459422638125131909455650914674984645296128*x^3 + + // 39326906381344642859585805381139474378267914375395728366952744024953935888385*x^2 + + // 13108968793781547619861935127046491459422638125131909455650914674984645296128*x + + // 39326906381344642859585805381139474378267914375395728366952744024953935888385 + let l0 = DensePolynomial::from_coefficients_vec(vec![ + field_new!( + Fr, + "39326906381344642859585805381139474378267914375395728366952744024953935888385" + ), + field_new!( + Fr, + "13108968793781547619861935127046491459422638125131909455650914674984645296128" + ), + field_new!( + Fr, + "39326906381344642859585805381139474378267914375395728366952744024953935888385" + ), + field_new!( + Fr, + "13108968793781547619861935127046491459422638125131909455650914674984645296128" + ), + ]); + + // ======================== + // l1: [0, 1, 0] + // ======================== + // sage: points = [(root1, 0), (root2, 1), (root3, 0), (root4, 0)] + // sage: P.lagrange_polynomial(points) + // 866286206518413079694067382671935694567563117191340490752*x^3 + + // 13108968793781547619861935127046491459422638125131909455650914674984645296128*x^2 + + // 52435875175126190478581454301667552757996485117855702128036095582747240693761*x + + // 39326906381344642859585805381139474378267914375395728366952744024953935888385 + let l1 = DensePolynomial::from_coefficients_vec(vec![ + field_new!( + Fr, + "39326906381344642859585805381139474378267914375395728366952744024953935888385" + ), + field_new!( + Fr, + "52435875175126190478581454301667552757996485117855702128036095582747240693761" + ), + field_new!( + Fr, + "13108968793781547619861935127046491459422638125131909455650914674984645296128" + ), + field_new!( + Fr, + "866286206518413079694067382671935694567563117191340490752" + ), + ]); + + // ======================== + // l2: [1, 3, 5] + // ======================== + // sage: points = [(root1, 1), (root2, 3), (root3, 5), (root4, 0)] + // sage: P.lagrange_polynomial(points) + // 2598858619555239239082202148015807083702689351574021472255*x^3 + + // 13108968793781547619861935127046491459422638125131909455650914674984645296129*x^2 + + // 52435875175126190476848881888630726598608350352511830738900969348364559712256*x + + // 39326906381344642859585805381139474378267914375395728366952744024953935888387 + let l2 = DensePolynomial::from_coefficients_vec(vec![ + field_new!( + Fr, + "39326906381344642859585805381139474378267914375395728366952744024953935888387" + ), + field_new!( + Fr, + "52435875175126190476848881888630726598608350352511830738900969348364559712256" + ), + field_new!( + Fr, + "13108968793781547619861935127046491459422638125131909455650914674984645296129" + ), + field_new!( + Fr, + "2598858619555239239082202148015807083702689351574021472255" + ), + ]); + + // ======================== + // l3: [2, 4, 6] + // ======================== + // sage: points = [(root1, 2), (root2, 4), (root3, 6), (root4, 0)] + // sage: P.lagrange_polynomial(points) + // 3465144826073652318776269530687742778270252468765361963007*x^3 + + // x^2 + + // 52435875175126190475982595682112313518914282969839895044333406231173219221504*x + + // 3 + let l3 = DensePolynomial::from_coefficients_vec(vec![ + Fr::from(3u64), + field_new!( + Fr, + "52435875175126190475982595682112313518914282969839895044333406231173219221504" + ), + Fr::one(), + field_new!( + Fr, + "3465144826073652318776269530687742778270252468765361963007" + ), + ]); + + assert_eq!(l0, l[0], "l0 not equal"); + assert_eq!(l1, l[1], "l1 not equal"); + assert_eq!(l2, l[2], "l2 not equal"); + assert_eq!(l3, l[3], "l3 not equal"); + } + Ok(()) + } + + #[test] + fn test_build_l() -> Result<(), ArithErrors> { + test_build_l_helper::() + } + + fn test_build_l_helper() -> Result<(), ArithErrors> { + // point 1 is [1, 2] + let point1 = vec![Fr::from(1u64), Fr::from(2u64)]; + + // point 2 is [3, 4] + let point2 = vec![Fr::from(3u64), Fr::from(4u64)]; + + // point 3 is [5, 6] + let point3 = vec![Fr::from(5u64), Fr::from(6u64)]; + + { + let domain = get_uni_domain::(2)?; + let l = build_l(&[point1.clone(), point2.clone()], &domain, false)?; + + // roots: [1, -1] + // l0 = -x + 2 + // l1 = -x + 3 + let l0 = DensePolynomial::from_coefficients_vec(vec![Fr::from(2u64), -Fr::one()]); + let l1 = DensePolynomial::from_coefficients_vec(vec![Fr::from(3u64), -Fr::one()]); + + assert_eq!(l0, l[0], "l0 not equal"); + assert_eq!(l1, l[1], "l1 not equal"); + } + + { + let domain = get_uni_domain::(3)?; + let l = build_l(&[point1, point2, point3], &domain, false)?; + + // sage: q = 52435875175126190479447740508185965837690552500527637822603658699938581184513 + // sage: P. = PolynomialRing(Zmod(q)) + // sage: root1 = 1 + // sage: root2 = 0x8D51CCCE760304D0EC030002760300000001000000000000 + // sage: root3 = -1 + // sage: root4 = -root2 + // Arkwork's code is a bit wired: it also interpolate (root4, 0) + // which returns a degree 3 polynomial, instead of degree 2 + + // ======================== + // l0: [1, 3, 5] + // ======================== + // sage: points = [(root1, 1), (root2, 3), (root3, 5), (root4, 0)] + // sage: P.lagrange_polynomial(points) + // 2598858619555239239082202148015807083702689351574021472255*x^3 + + // 13108968793781547619861935127046491459422638125131909455650914674984645296129*x^2 + + // 52435875175126190476848881888630726598608350352511830738900969348364559712256*x + + // 39326906381344642859585805381139474378267914375395728366952744024953935888387 + let l0 = DensePolynomial::from_coefficients_vec(vec![ + field_new!( + Fr, + "39326906381344642859585805381139474378267914375395728366952744024953935888387" + ), + field_new!( + Fr, + "52435875175126190476848881888630726598608350352511830738900969348364559712256" + ), + field_new!( + Fr, + "13108968793781547619861935127046491459422638125131909455650914674984645296129" + ), + field_new!( + Fr, + "2598858619555239239082202148015807083702689351574021472255" + ), + ]); + + // ======================== + // l1: [2, 4, 6] + // ======================== + // sage: points = [(root1, 2), (root2, 4), (root3, 6), (root4, 0)] + // sage: P.lagrange_polynomial(points) + // 3465144826073652318776269530687742778270252468765361963007*x^3 + + // x^2 + + // 52435875175126190475982595682112313518914282969839895044333406231173219221504*x + + // 3 + let l1 = DensePolynomial::from_coefficients_vec(vec![ + Fr::from(3u64), + field_new!( + Fr, + "52435875175126190475982595682112313518914282969839895044333406231173219221504" + ), + Fr::one(), + field_new!( + Fr, + "3465144826073652318776269530687742778270252468765361963007" + ), + ]); + + assert_eq!(l0, l[0], "l0 not equal"); + assert_eq!(l1, l[1], "l1 not equal"); + } + Ok(()) + } +} diff --git a/arithmetic/src/util.rs b/arithmetic/src/util.rs new file mode 100644 index 0000000..c7f21f5 --- /dev/null +++ b/arithmetic/src/util.rs @@ -0,0 +1,96 @@ +use ark_ff::PrimeField; +use ark_std::log2; + +/// Decompose an integer into a binary vector in little endian. +pub fn bit_decompose(input: u64, num_var: usize) -> Vec { + let mut res = Vec::with_capacity(num_var); + let mut i = input; + for _ in 0..num_var { + res.push(i & 1 == 1); + i >>= 1; + } + res +} + +/// given the evaluation input `point` of the `index`-th polynomial, +/// obtain the evaluation point in the merged polynomial +pub fn gen_eval_point(index: usize, index_len: usize, point: &[F]) -> Vec { + let index_vec: Vec = bit_decompose(index as u64, index_len) + .into_iter() + .map(|x| F::from(x)) + .collect(); + [point, &index_vec].concat() +} + +/// Return the number of variables that one need for an MLE to +/// batch the list of MLEs +#[inline] +pub fn get_batched_nv(num_var: usize, polynomials_len: usize) -> usize { + num_var + log2(polynomials_len) as usize +} + +// Input index +// - `i := (i_0, ...i_{n-1})`, +// - `num_vars := n` +// return three elements: +// - `x0 := (i_1, ..., i_{n-1}, 0)` +// - `x1 := (i_1, ..., i_{n-1}, 1)` +// - `sign := i_0` +#[inline] +pub fn get_index(i: usize, num_vars: usize) -> (usize, usize, bool) { + let bit_sequence = bit_decompose(i as u64, num_vars); + + // the last bit comes first here because of LE encoding + let x0 = project(&[[false].as_ref(), bit_sequence[..num_vars - 1].as_ref()].concat()) as usize; + let x1 = project(&[[true].as_ref(), bit_sequence[..num_vars - 1].as_ref()].concat()) as usize; + + (x0, x1, bit_sequence[num_vars - 1]) +} + +/// Project a little endian binary vector into an integer. +#[inline] +pub(crate) fn project(input: &[bool]) -> u64 { + let mut res = 0; + for &e in input.iter().rev() { + res <<= 1; + res += e as u64; + } + res +} + +#[cfg(test)] +mod test { + use super::{bit_decompose, get_index, project}; + use ark_std::{rand::RngCore, test_rng}; + + #[test] + fn test_decomposition() { + let mut rng = test_rng(); + for _ in 0..100 { + let t = rng.next_u64(); + let b = bit_decompose(t, 64); + let r = project(&b); + assert_eq!(t, r) + } + } + + #[test] + fn test_get_index() { + let a = 0b1010; + let (x0, x1, sign) = get_index(a, 4); + assert_eq!(x0, 0b0100); + assert_eq!(x1, 0b0101); + assert!(sign); + + let (x0, x1, sign) = get_index(a, 5); + assert_eq!(x0, 0b10100); + assert_eq!(x1, 0b10101); + assert!(!sign); + + let a = 0b1111; + let (x0, x1, sign) = get_index(a, 4); + assert_eq!(x0, 0b1110); + assert_eq!(x1, 0b1111); + assert!(sign); + } +} diff --git a/arithmetic/src/virtual_polynomial.rs b/arithmetic/src/virtual_polynomial.rs index e44d9b2..bb6b44c 100644 --- a/arithmetic/src/virtual_polynomial.rs +++ b/arithmetic/src/virtual_polynomial.rs @@ -1,7 +1,7 @@ //! This module defines our main mathematical object `VirtualPolynomial`; and //! various functions associated with it. -use crate::{errors::ArithErrors, multilinear_polynomial::random_zero_mle_list}; +use crate::{errors::ArithErrors, multilinear_polynomial::random_zero_mle_list, random_mle_list}; use ark_ff::PrimeField; use ark_poly::{DenseMultilinearExtension, MultilinearExtension}; use ark_serialize::{CanonicalSerialize, SerializationError, Write}; @@ -324,42 +324,6 @@ impl VirtualPolynomial { } } -/// Sample a random list of multilinear polynomials. -/// Returns -/// - the list of polynomials, -/// - its sum of polynomial evaluations over the boolean hypercube. -fn random_mle_list( - nv: usize, - degree: usize, - rng: &mut R, -) -> (Vec>>, F) { - let start = start_timer!(|| "sample random mle list"); - let mut multiplicands = Vec::with_capacity(degree); - for _ in 0..degree { - multiplicands.push(Vec::with_capacity(1 << nv)) - } - let mut sum = F::zero(); - - for _ in 0..(1 << nv) { - let mut product = F::one(); - - for e in multiplicands.iter_mut() { - let val = F::rand(rng); - e.push(val); - product *= val; - } - sum += product; - } - - let list = multiplicands - .into_iter() - .map(|x| Rc::new(DenseMultilinearExtension::from_evaluations_vec(nv, x))) - .collect(); - - end_timer!(start); - (list, sum) -} - // This function build the eq(x, r) polynomial for any given r. // // Evaluate diff --git a/hyperplonk/Cargo.toml b/hyperplonk/Cargo.toml index 200eb58..042f99e 100644 --- a/hyperplonk/Cargo.toml +++ b/hyperplonk/Cargo.toml @@ -6,7 +6,6 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -poly-iop = { path = "../poly-iop" } ark-std = { version = "^0.3.0", default-features = false } ark-ec = { version = "^0.3.0", default-features = false } @@ -15,33 +14,49 @@ ark-poly = { version = "^0.3.0", default-features = false } ark-serialize = { version = "^0.3.0", default-features = false, features = [ "derive" ] } displaydoc = { version = "0.2.3", default-features = false } + +poly-iop = { path = "../poly-iop" } +pcs = { path = "../pcs" } transcript = { path = "../transcript" } arithmetic = { path = "../arithmetic" } +util = { path = "../util" } -jf-primitives = { git = "https://github.com/EspressoSystems/jellyfish", rev = "ff43209" } +rayon = { version = "1.5.2", default-features = false, optional = true } [dev-dependencies] ark-bls12-381 = { version = "0.3.0", default-features = false, features = [ "curve" ] } +# Benchmarks +[[bench]] +name = "hyperplonk-benches" +path = "benches/bench.rs" +harness = false [features] -# default = [ "parallel", "print-trace", "extensive_sanity_checks" ] -default = [ "parallel", "extensive_sanity_checks" ] - +# default = [ ] +default = [ "parallel" ] +# default = [ "parallel", "print-trace" ] +# default = [ "parallel", "extensive_sanity_checks" ] +bench = [ "parallel" ] # extensive sanity checks that are useful for debugging -extensive_sanity_checks = [ ] +extensive_sanity_checks = [ + "poly-iop/extensive_sanity_checks", + "pcs/extensive_sanity_checks", + ] parallel = [ + "rayon", "ark-std/parallel", "ark-ff/parallel", "ark-poly/parallel", "ark-ec/parallel", + "poly-iop/parallel", "arithmetic/parallel", - "jf-primitives/parallel", + "pcs/parallel", + "util/parallel" ] print-trace = [ "ark-std/print-trace", "poly-iop/print-trace", "arithmetic/print-trace", - "jf-primitives/print-trace", ] \ No newline at end of file diff --git a/hyperplonk/benches/bench.rs b/hyperplonk/benches/bench.rs new file mode 100644 index 0000000..d819230 --- /dev/null +++ b/hyperplonk/benches/bench.rs @@ -0,0 +1,149 @@ +use std::{env, fs::File, time::Instant}; + +use ark_bls12_381::{Bls12_381, Fr}; +use ark_serialize::Write; +use ark_std::test_rng; +use hyperplonk::{ + prelude::{CustomizedGates, HyperPlonkErrors, MockCircuit}, + HyperPlonkSNARK, +}; +use pcs::{ + prelude::{MultilinearKzgPCS, MultilinearUniversalParams, UnivariateUniversalParams}, + PolynomialCommitmentScheme, +}; +use poly_iop::PolyIOP; +use rayon::ThreadPoolBuilder; + +fn main() -> Result<(), HyperPlonkErrors> { + let args: Vec = env::args().collect(); + let thread = args[1].parse().unwrap_or(12); + + ThreadPoolBuilder::new() + .num_threads(thread) + .build_global() + .unwrap(); + bench_vanilla_plonk(thread)?; + for degree in [1, 2, 4, 8, 16, 32] { + bench_high_degree_plonk(degree, thread)?; + } + + Ok(()) +} + +fn bench_vanilla_plonk(thread: usize) -> Result<(), HyperPlonkErrors> { + let mut rng = test_rng(); + let pcs_srs = MultilinearKzgPCS::::gen_srs_for_testing(&mut rng, 22)?; + + let filename = format!("vanilla nv {}.txt", thread); + let mut file = File::create(filename).unwrap(); + for nv in 1..16 { + let vanilla_gate = CustomizedGates::vanilla_plonk_gate(); + bench_mock_circuit_zkp_helper(&mut file, nv, &vanilla_gate, &pcs_srs)?; + } + + Ok(()) +} + +fn bench_high_degree_plonk(degree: usize, thread: usize) -> Result<(), HyperPlonkErrors> { + let mut rng = test_rng(); + let pcs_srs = MultilinearKzgPCS::::gen_srs_for_testing(&mut rng, 22)?; + + let filename = format!("high degree {} thread {}.txt", degree, thread); + let mut file = File::create(filename).unwrap(); + for nv in 1..16 { + let vanilla_gate = CustomizedGates::vanilla_plonk_gate(); + bench_mock_circuit_zkp_helper(&mut file, nv, &vanilla_gate, &pcs_srs)?; + } + + Ok(()) +} + +fn bench_mock_circuit_zkp_helper( + file: &mut File, + nv: usize, + gate: &CustomizedGates, + pcs_srs: &( + MultilinearUniversalParams, + UnivariateUniversalParams, + ), +) -> Result<(), HyperPlonkErrors> { + let repetition = if nv < 10 { + 10 + } else if nv < 20 { + 5 + } else { + 2 + }; + + //========================================================== + let start = Instant::now(); + for _ in 0..repetition { + let circuit = MockCircuit::::new(1 << nv, gate); + assert!(circuit.is_satisfied()); + } + println!( + "mock circuit gen for {} variables: {} ns", + nv, + start.elapsed().as_nanos() / repetition as u128 + ); + + let circuit = MockCircuit::::new(1 << nv, gate); + assert!(circuit.is_satisfied()); + let index = circuit.index; + //========================================================== + // generate pk and vks + let start = Instant::now(); + for _ in 0..repetition { + let (_pk, _vk) = as HyperPlonkSNARK< + Bls12_381, + MultilinearKzgPCS, + >>::preprocess(&index, &pcs_srs)?; + } + println!( + "key extraction for {} variables: {} us", + nv, + start.elapsed().as_micros() / repetition as u128 + ); + let (pk, vk) = + as HyperPlonkSNARK>>::preprocess( + &index, &pcs_srs, + )?; + //========================================================== + // generate a proof + let start = Instant::now(); + for _ in 0..repetition { + let _proof = + as HyperPlonkSNARK>>::prove( + &pk, + &circuit.witnesses[0].coeff_ref(), + &circuit.witnesses, + )?; + } + let t = start.elapsed().as_micros() / repetition as u128; + println!("proving for {} variables: {} us", nv, t); + file.write_all(format!("{} {}\n", nv, t).as_ref()).unwrap(); + + let proof = as HyperPlonkSNARK>>::prove( + &pk, + &circuit.witnesses[0].coeff_ref(), + &circuit.witnesses, + )?; + //========================================================== + // verify a proof + let start = Instant::now(); + for _ in 0..repetition { + let verify = + as HyperPlonkSNARK>>::verify( + &vk, + &circuit.witnesses[0].coeff_ref(), + &proof, + )?; + assert!(verify); + } + println!( + "verifying for {} variables: {} us", + nv, + start.elapsed().as_micros() / repetition as u128 + ); + Ok(()) +} diff --git a/hyperplonk/src/custom_gate.rs b/hyperplonk/src/custom_gate.rs new file mode 100644 index 0000000..1586963 --- /dev/null +++ b/hyperplonk/src/custom_gate.rs @@ -0,0 +1,158 @@ +use ark_std::cmp::max; + +/// Customized gate is a list of tuples of +/// (coefficient, selector_index, wire_indices) +/// +/// Example: +/// q_L(X) * W_1(X)^5 - W_2(X) = 0 +/// is represented as +/// vec![ +/// ( 1, Some(id_qL), vec![id_W1, id_W1, id_W1, id_W1, id_W1]), +/// (-1, None, vec![id_W2]) +/// ] +/// +/// CustomizedGates { +/// gates: vec![ +/// (1, Some(0), vec![0, 0, 0, 0, 0]), +/// (-1, None, vec![1]) +/// ], +/// }; +/// where id_qL = 0 // first selector +/// id_W1 = 0 // first witness +/// id_w2 = 1 // second witness +/// +/// NOTE: here coeff is a signed integer, instead of a field element +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct CustomizedGates { + pub(crate) gates: Vec<(i64, Option, Vec)>, +} + +impl CustomizedGates { + /// The degree of the algebraic customized gate + pub fn degree(&self) -> usize { + let mut res = 0; + for x in self.gates.iter() { + res = max(res, x.2.len() + (x.1.is_some() as usize)) + } + res + } + + /// The number of selectors in a customized gate + pub fn num_selector_columns(&self) -> usize { + let mut res = 0; + for (_coeff, q, _ws) in self.gates.iter() { + // a same selector must not be used for multiple monomials. + if q.is_some() { + res += 1; + } + } + res + } + + /// The number of witnesses in a customized gate + pub fn num_witness_columns(&self) -> usize { + let mut res = 0; + for (_coeff, _q, ws) in self.gates.iter() { + // witness list must be ordered + // so we just need to compare with the last one + if let Some(&p) = ws.last() { + if res < p { + res = p + } + } + } + // add one here because index starts from 0 + res + 1 + } + + /// Return a vanilla plonk gate: + /// ``` ignore + /// q_L w_1 + q_R w_2 + q_O w_3 + q_M w1w2 + q_C = 0 + /// ``` + /// which is + /// ``` ignore + /// (1, Some(id_qL), vec![id_W1]), + /// (1, Some(id_qR), vec![id_W2]), + /// (1, Some(id_qO), vec![id_W3]), + /// (1, Some(id_qM), vec![id_W1, id_w2]), + /// (1, Some(id_qC), vec![]), + /// ``` + pub fn vanilla_plonk_gate() -> Self { + Self { + gates: vec![ + (1, Some(0), vec![0]), + (1, Some(1), vec![1]), + (1, Some(2), vec![2]), + (1, Some(3), vec![0, 1]), + (1, Some(4), vec![]), + ], + } + } + + /// Return a jellyfish turbo plonk gate: + /// ```ignore + /// q_1 w_1 + q_2 w_2 + q_3 w_3 + q_4 w4 + /// + q_M1 w1w2 + q_M2 w3w4 + /// + q_H1 w1^5 + q_H2 w2^5 + q_H3 w1^5 + q_H4 w2^5 + /// + q_E w1w2w3w4 + /// + q_O w5 + /// + q_C + /// = 0 + /// ``` + /// with + /// - w = [w1, w2, w3, w4, w5] + /// - q = [ q_1, q_2, q_3, q_4, q_M1, q_M2, q_H1, q_H2, q_H3, q_H4, q_E, + /// q_O, q_c ] + /// + /// which is + /// ```ignore + /// (1, Some(q[0]), vec![w[0]]), + /// (1, Some(q[1]), vec![w[1]]), + /// (1, Some(q[2]), vec![w[2]]), + /// (1, Some(q[3]), vec![w[3]]), + /// (1, Some(q[4]), vec![w[0], w[1]]), + /// (1, Some(q[5]), vec![w[2], w[3]]), + /// (1, Some(q[6]), vec![w[0], w[0], w[0], w[0], w[0]]), + /// (1, Some(q[7]), vec![w[1], w[1], w[1], w[1], w[1]]), + /// (1, Some(q[8]), vec![w[2], w[2], w[2], w[2], w[2]]), + /// (1, Some(q[9]), vec![w[3], w[3], w[3], w[3], w[3]]), + /// (1, Some(q[10]), vec![w[0], w[1], w[2], w[3]]), + /// (1, Some(q[11]), vec![w[4]]), + /// (1, Some(q[12]), vec![]), + /// ``` + pub fn jellyfish_turbo_plonk_gate() -> Self { + CustomizedGates { + gates: vec![ + (1, Some(0), vec![0]), + (1, Some(1), vec![1]), + (1, Some(2), vec![2]), + (1, Some(3), vec![3]), + (1, Some(4), vec![0, 1]), + (1, Some(5), vec![2, 3]), + (1, Some(6), vec![0, 0, 0, 0, 0]), + (1, Some(7), vec![1, 1, 1, 1, 1]), + (1, Some(8), vec![2, 2, 2, 2, 2]), + (1, Some(9), vec![3, 3, 3, 3, 3]), + (1, Some(10), vec![0, 1, 2, 3]), + (1, Some(11), vec![4]), + (1, Some(12), vec![]), + ], + } + } + + /// Generate a random gate for `num_witness` with a highest degree = + /// `degree` + pub fn mock_gate(num_witness: usize, degree: usize) -> Self { + let mut gates = vec![]; + + let high_degree_term = vec![0; degree]; + + gates.push((1, Some(0), high_degree_term)); + for i in 1..num_witness { + gates.push((1, Some(i), vec![i])) + } + gates.push((1, Some(num_witness), vec![])); + + CustomizedGates { gates } + } +} diff --git a/hyperplonk/src/errors.rs b/hyperplonk/src/errors.rs index 20caa25..83a17a3 100644 --- a/hyperplonk/src/errors.rs +++ b/hyperplonk/src/errors.rs @@ -4,9 +4,9 @@ use arithmetic::ArithErrors; use ark_serialize::SerializationError; use ark_std::string::String; use displaydoc::Display; -use jf_primitives::pcs::prelude::PCSError; +use pcs::prelude::PCSError; use poly_iop::prelude::PolyIOPErrors; -use transcript::TranscriptErrors; +use transcript::TranscriptError; /// A `enum` specifying the possible failure modes of hyperplonk. #[derive(Display, Debug)] @@ -24,9 +24,9 @@ pub enum HyperPlonkErrors { /// PolyIOP error {0} PolyIOPErrors(PolyIOPErrors), /// PCS error {0} - PCSError(PCSError), + PCSErrors(PCSError), /// Transcript error {0} - TranscriptError(TranscriptErrors), + TranscriptError(TranscriptError), /// Arithmetic Error: {0} ArithmeticErrors(ArithErrors), } @@ -45,12 +45,12 @@ impl From for HyperPlonkErrors { impl From for HyperPlonkErrors { fn from(e: PCSError) -> Self { - Self::PCSError(e) + Self::PCSErrors(e) } } -impl From for HyperPlonkErrors { - fn from(e: TranscriptErrors) -> Self { +impl From for HyperPlonkErrors { + fn from(e: TranscriptError) -> Self { Self::TranscriptError(e) } } diff --git a/hyperplonk/src/lib.rs b/hyperplonk/src/lib.rs index 1ad18d7..4df019e 100644 --- a/hyperplonk/src/lib.rs +++ b/hyperplonk/src/lib.rs @@ -1,27 +1,17 @@ //! Main module for the HyperPlonk SNARK. -use crate::utils::{eval_f, prove_sanity_check}; -use arithmetic::VPAuxInfo; use ark_ec::PairingEngine; -use ark_poly::{DenseMultilinearExtension, MultilinearExtension}; -use ark_std::{ - borrow::Borrow, end_timer, log2, marker::PhantomData, rc::Rc, start_timer, One, Zero, -}; use errors::HyperPlonkErrors; -use jf_primitives::pcs::prelude::{ - compute_qx_degree, merge_polynomials, PCSError, PolynomialCommitmentScheme, -}; -use poly_iop::{ - prelude::{identity_permutation_mle, PermutationCheck, ZeroCheck}, - PolyIOP, -}; -use structs::{HyperPlonkIndex, HyperPlonkProof, HyperPlonkProvingKey, HyperPlonkVerifyingKey}; -use transcript::IOPTranscript; -use utils::{build_f, gen_eval_point}; +use pcs::prelude::PolynomialCommitmentScheme; +use poly_iop::prelude::PermutationCheck; use witness::WitnessColumn; +mod custom_gate; mod errors; +mod mock; +pub mod prelude; mod selectors; +mod snark; mod structs; mod utils; mod witness; @@ -50,7 +40,7 @@ where /// polynomial commitments fn preprocess( index: &Self::Index, - pcs_srs: impl Borrow, + pcs_srs: &PCS::SRS, ) -> Result<(Self::ProvingKey, Self::VerifyingKey), HyperPlonkErrors>; /// Generate HyperPlonk SNARK proof. @@ -62,7 +52,7 @@ where /// Outputs: /// - The HyperPlonk SNARK proof. fn prove( - pk: impl Borrow, + pk: &Self::ProvingKey, pub_input: &[E::Fr], witnesses: &[WitnessColumn], ) -> Result; @@ -76,974 +66,8 @@ where /// Outputs: /// - Return a boolean on whether the verification is successful fn verify( - vk: impl Borrow, + vk: &Self::VerifyingKey, pub_input: &[E::Fr], proof: &Self::Proof, ) -> Result; } - -impl HyperPlonkSNARK for PolyIOP -where - E: PairingEngine, - // Ideally we want to access polynomial as PCS::Polynomial, instead of instantiating it here. - // But since PCS::Polynomial can be both univariate or multivariate in our implementation - // we cannot bound PCS::Polynomial with a property trait bound. - PCS: PolynomialCommitmentScheme< - E, - Polynomial = Rc>, - Point = Vec, - Evaluation = E::Fr, - >, -{ - type Index = HyperPlonkIndex; - type ProvingKey = HyperPlonkProvingKey; - type VerifyingKey = HyperPlonkVerifyingKey; - type Proof = HyperPlonkProof; - - fn preprocess( - index: &Self::Index, - pcs_srs: impl Borrow, - ) -> Result<(Self::ProvingKey, Self::VerifyingKey), HyperPlonkErrors> { - let num_vars = index.params.nv; - let log_num_witness_polys = index.params.log_n_wires; - - // number of variables in merged polynomial for Multilinear-KZG - let merged_nv = num_vars + log_num_witness_polys; - // degree of q(x) for Univariate-KZG - let supported_uni_degree = compute_qx_degree(num_vars, 1 << log_num_witness_polys); - - // extract PCS prover and verifier keys from SRS - let (pcs_prover_param, pcs_verifier_param) = PCS::trim( - pcs_srs, - log2(supported_uni_degree) as usize, - Some(merged_nv + 1), - )?; - - // build permutation oracles - let permutation_oracle = Rc::new(DenseMultilinearExtension::from_evaluations_slice( - merged_nv, - &index.permutation, - )); - let perm_com = PCS::commit(&pcs_prover_param, &permutation_oracle)?; - - // build selector oracles and commit to it - let selector_oracles: Vec>> = index - .selectors - .iter() - .map(|s| Rc::new(DenseMultilinearExtension::from(s))) - .collect(); - - let selector_com = selector_oracles - .iter() - .map(|poly| PCS::commit(&pcs_prover_param, poly)) - .collect::, PCSError>>()?; - - Ok(( - Self::ProvingKey { - params: index.params.clone(), - permutation_oracle, - selector_oracles, - pcs_param: pcs_prover_param, - }, - Self::VerifyingKey { - params: index.params.clone(), - pcs_param: pcs_verifier_param, - selector_com, - perm_com, - }, - )) - } - - /// Generate HyperPlonk SNARK proof. - /// - /// Inputs: - /// - `pk`: circuit proving key - /// - `pub_input`: online public input of length 2^\ell - /// - `witness`: witness assignment of length 2^n - /// Outputs: - /// - The HyperPlonk SNARK proof. - /// - /// Steps: - /// - /// 1. Commit Witness polynomials `w_i(x)` and append commitment to - /// transcript - /// - /// 2. Run ZeroCheck on - /// - /// `f(q_0(x),...q_l(x), w_0(x),...w_d(x))` - /// - /// where `f` is the constraint polynomial i.e., - /// ```ignore - /// f(q_l, q_r, q_m, q_o, w_a, w_b, w_c) - /// = q_l w_a(x) + q_r w_b(x) + q_m w_a(x)w_b(x) - q_o w_c(x) - /// ``` - /// in vanilla plonk, and obtain a ZeroCheckSubClaim - /// - /// 3. Run permutation check on `\{w_i(x)\}` and `permutation_oracle`, and - /// obtain a PermCheckSubClaim. - /// - /// 4. Generate evaluations and corresponding proofs - /// - permutation check evaluations and proofs - /// - zero check evaluations and proofs - /// - public input consistency checks - /// - /// TODO: this function is gigantic -- refactor it to smaller ones - fn prove( - pk: impl Borrow, - pub_input: &[E::Fr], - witnesses: &[WitnessColumn], - ) -> Result { - let pk = pk.borrow(); - let start = start_timer!(|| "hyperplonk proving"); - let mut transcript = IOPTranscript::::new(b"hyperplonk"); - - prove_sanity_check(&pk.params, pub_input, witnesses)?; - - // witness assignment of length 2^n - let num_vars = pk.params.nv; - let log_num_witness_polys = pk.params.log_n_wires; - // number of variables in merged polynomial for Multilinear-KZG - let merged_nv = num_vars + log_num_witness_polys; - // degree of q(x) for Univariate-KZG - let _supported_uni_degree = compute_qx_degree(num_vars, 1 << log_num_witness_polys); - // online public input of length 2^\ell - let ell = pk.params.log_pub_input_len; - - let witness_polys: Vec>> = witnesses - .iter() - .map(|w| Rc::new(DenseMultilinearExtension::from(w))) - .collect(); - let pi_poly = Rc::new(DenseMultilinearExtension::from_evaluations_slice( - ell as usize, - pub_input, - )); - - // ======================================================================= - // 1. Commit Witness polynomials `w_i(x)` and append commitment to - // transcript - // ======================================================================= - let step = start_timer!(|| "commit witnesses"); - // TODO(Chengyu): update `merge_polynomials` method in jellyfish repo. - let w_merged = Rc::new(merge_polynomials(&witness_polys)?); - if w_merged.num_vars != merged_nv { - return Err(HyperPlonkErrors::InvalidParameters(format!( - "merged witness poly has a different num_vars ({}) from expected ({})", - w_merged.num_vars, merged_nv - ))); - } - let w_merged_com = PCS::commit(&pk.pcs_param, &w_merged)?; - - transcript.append_serializable_element(b"w", &w_merged_com)?; - end_timer!(step); - - // ======================================================================= - // 2 Run ZeroCheck on - // - // `f(q_0(x),...q_l(x), w_0(x),...w_d(x))` - // - // where `f` is the constraint polynomial i.e., - // - // f(q_l, q_r, q_m, q_o, w_a, w_b, w_c) - // = q_l w_a(x) + q_r w_b(x) + q_m w_a(x)w_b(x) - q_o w_c(x) - // - // in vanilla plonk, and obtain a ZeroCheckSubClaim - // ======================================================================= - let step = start_timer!(|| "ZeroCheck on f"); - - let fx = build_f( - &pk.params.gate_func, - pk.params.nv, - &pk.selector_oracles, - &witness_polys, - )?; - - let zero_check_proof = >::prove(&fx, &mut transcript)?; - end_timer!(step); - - // ======================================================================= - // 3. Run permutation check on `\{w_i(x)\}` and `permutation_oracle`, and - // obtain a PermCheckSubClaim. - // ======================================================================= - let step = start_timer!(|| "Permutation check on w_i(x)"); - - let (perm_check_proof, prod_x) = >::prove( - &pk.pcs_param, - &w_merged, - &w_merged, - &pk.permutation_oracle, - &mut transcript, - )?; - - // open prod(0,x), prod(1, x), prod(x, 0), prod(x, 1) at zero_check.point - // prod(0, x) - let tmp_point = [ - perm_check_proof.zero_check_proof.point.as_slice(), - &[E::Fr::zero()], - ] - .concat(); - let (prod_0_x_opening, prod_0_x_eval) = PCS::open(&pk.pcs_param, &prod_x, &tmp_point)?; - #[cfg(feature = "extensive_sanity_checks")] - { - // sanity check - let eval = prod_x.evaluate(&tmp_point).ok_or_else(|| { - HyperPlonkErrors::InvalidParameters( - "prod_0_x evaluation dimension does not match".to_string(), - ) - })?; - if eval != prod_0_x_eval { - return Err(HyperPlonkErrors::InvalidProver( - "prod_0_x evaluation is different from PCS opening".to_string(), - )); - } - } - // prod(1, x) - let tmp_point = [ - perm_check_proof.zero_check_proof.point.as_slice(), - &[E::Fr::one()], - ] - .concat(); - let (prod_1_x_opening, prod_1_x_eval) = PCS::open(&pk.pcs_param, &prod_x, &tmp_point)?; - #[cfg(feature = "extensive_sanity_checks")] - { - // sanity check - let eval = prod_x.evaluate(&tmp_point).ok_or_else(|| { - HyperPlonkErrors::InvalidParameters( - "prod_1_x evaluation dimension does not match".to_string(), - ) - })?; - if eval != prod_1_x_eval { - return Err(HyperPlonkErrors::InvalidProver( - "prod_1_x evaluation is different from PCS opening".to_string(), - )); - } - } - // prod(x, 0) - let tmp_point = [ - &[E::Fr::zero()], - perm_check_proof.zero_check_proof.point.as_slice(), - ] - .concat(); - let (prod_x_0_opening, prod_x_0_eval) = PCS::open(&pk.pcs_param, &prod_x, &tmp_point)?; - #[cfg(feature = "extensive_sanity_checks")] - { - // sanity check - let eval = prod_x.evaluate(&tmp_point).ok_or_else(|| { - HyperPlonkErrors::InvalidParameters( - "prod_x_0 evaluation dimension does not match".to_string(), - ) - })?; - - if eval != prod_x_0_eval { - return Err(HyperPlonkErrors::InvalidProver( - "prod_x_0 evaluation is different from PCS opening".to_string(), - )); - } - } - // prod(x, 1) - let tmp_point = [ - &[E::Fr::one()], - perm_check_proof.zero_check_proof.point.as_slice(), - ] - .concat(); - let (prod_x_1_opening, prod_x_1_eval) = PCS::open(&pk.pcs_param, &prod_x, &tmp_point)?; - #[cfg(feature = "extensive_sanity_checks")] - { - // sanity check - let eval = prod_x.evaluate(&tmp_point).ok_or_else(|| { - HyperPlonkErrors::InvalidParameters( - "prod_x_1 evaluation dimension does not match".to_string(), - ) - })?; - if eval != prod_x_1_eval { - return Err(HyperPlonkErrors::InvalidProver( - "prod_x_1 evaluation is different from PCS opening".to_string(), - )); - } - } - // prod(1, ..., 1, 0) - let tmp_point = [vec![E::Fr::zero()], vec![E::Fr::one(); merged_nv]].concat(); - let (prod_1_0_opening, prod_1_0_eval) = PCS::open(&pk.pcs_param, &prod_x, &tmp_point)?; - #[cfg(feature = "extensive_sanity_checks")] - { - // sanity check - if prod_1_0_eval != E::Fr::one() { - return Err(HyperPlonkErrors::InvalidProver(format!( - "prod_1_0 evaluation is not one: got {}", - prod_1_0_eval, - ))); - } - } - end_timer!(step); - - // ======================================================================= - // 4. Generate evaluations and corresponding proofs - // - permutation check evaluations and proofs - // - wi_poly(r_perm_check) where r_perm_check is from perm_check_proof - // - selector_poly(r_perm_check) - // - // - zero check evaluations and proofs - // - wi_poly(r_zero_check) where r_zero_check is from zero_check_proof - // - selector_poly(r_zero_check) - // - // - public input consistency checks - // - pi_poly(r_pi) where r_pi is sampled from transcript - // ======================================================================= - let step = start_timer!(|| "opening and evaluations"); - - // 4.1 permutation check - let mut witness_zero_check_evals = vec![]; - let mut witness_zero_check_openings = vec![]; - // TODO: parallelization - // TODO: Batch opening - - // open permutation check proof - let (witness_perm_check_opening, witness_perm_check_eval) = PCS::open( - &pk.pcs_param, - &w_merged, - &perm_check_proof.zero_check_proof.point, - )?; - - #[cfg(feature = "extensive_sanity_checks")] - { - // sanity checks - let eval = w_merged - .evaluate(&perm_check_proof.zero_check_proof.point) - .ok_or_else(|| { - HyperPlonkErrors::InvalidParameters( - "witness_perm_check evaluation dimension does not match".to_string(), - ) - })?; - if eval != witness_perm_check_eval { - return Err(HyperPlonkErrors::InvalidProver( - "witness_perm_check evaluation is different from PCS opening".to_string(), - )); - } - } - - // 4.2 open zero check proof - // TODO: batch opening - for (i, wire_poly) in witness_polys.iter().enumerate() { - let tmp_point = gen_eval_point(i, log_num_witness_polys, &zero_check_proof.point); - // Open zero check proof - let (zero_proof, zero_eval) = PCS::open(&pk.pcs_param, &w_merged, &tmp_point)?; - { - let eval = wire_poly.evaluate(&zero_check_proof.point).ok_or_else(|| { - HyperPlonkErrors::InvalidParameters( - "witness_zero_check evaluation dimension does not match".to_string(), - ) - })?; - if eval != zero_eval { - return Err(HyperPlonkErrors::InvalidProver( - "witness_zero_check evaluation is different from PCS opening".to_string(), - )); - } - } - witness_zero_check_evals.push(zero_eval); - witness_zero_check_openings.push(zero_proof); - } - - // Open permutation polynomial at perm_check_point - let (s_perm_opening, s_perm_eval) = PCS::open( - &pk.pcs_param, - &pk.permutation_oracle, - &perm_check_proof.zero_check_proof.point, - )?; - - #[cfg(feature = "extensive_sanity_checks")] - { - // sanity check - let eval = pk - .permutation_oracle - .evaluate(&perm_check_proof.zero_check_proof.point) - .ok_or_else(|| { - HyperPlonkErrors::InvalidParameters( - "perm_oracle evaluation dimension does not match".to_string(), - ) - })?; - if eval != s_perm_eval { - return Err(HyperPlonkErrors::InvalidProver( - "perm_oracle evaluation is different from PCS opening".to_string(), - )); - } - } - - // Open selector polynomial at zero_check_point - let mut selector_oracle_openings = vec![]; - let mut selector_oracle_evals = vec![]; - - // TODO: parallelization - for selector_poly in pk.selector_oracles.iter() { - // Open zero check proof - // during verification, use this eval against subclaim - let (zero_proof, zero_eval) = - PCS::open(&pk.pcs_param, selector_poly, &zero_check_proof.point)?; - - #[cfg(feature = "extensive_sanity_checks")] - { - let eval = selector_poly - .evaluate(&zero_check_proof.point) - .ok_or_else(|| { - HyperPlonkErrors::InvalidParameters( - "selector evaluation dimension does not match".to_string(), - ) - })?; - if eval != zero_eval { - return Err(HyperPlonkErrors::InvalidProver( - "selector evaluation is different from PCS opening".to_string(), - )); - } - } - selector_oracle_openings.push(zero_proof); - selector_oracle_evals.push(zero_eval); - } - - // 4.3 public input consistency checks - let r_pi = transcript.get_and_append_challenge_vectors(b"r_pi", ell)?; - let tmp_point = [ - vec![E::Fr::zero(); num_vars - ell], - r_pi.clone(), - vec![E::Fr::zero(); log_num_witness_polys], - ] - .concat(); - let (pi_opening, pi_eval) = PCS::open(&pk.pcs_param, &w_merged, &tmp_point)?; - - #[cfg(feature = "extensive_sanity_checks")] - { - // sanity check - let eval = pi_poly.evaluate(&r_pi).ok_or_else(|| { - HyperPlonkErrors::InvalidParameters( - "public input evaluation dimension does not match".to_string(), - ) - })?; - if eval != pi_eval { - return Err(HyperPlonkErrors::InvalidProver( - "public input evaluation is different from PCS opening".to_string(), - )); - } - } - - end_timer!(step); - end_timer!(start); - - Ok(HyperPlonkProof { - // ======================================================================= - // PCS components: common - // ======================================================================= - w_merged_com, - // ======================================================================= - // PCS components: permutation check - // ======================================================================= - // We do not validate prod(x), this is checked by subclaim - prod_evals: vec![prod_0_x_eval, prod_1_x_eval, prod_x_0_eval, prod_x_1_eval], - prod_openings: vec![ - prod_0_x_opening, - prod_1_x_opening, - prod_x_0_opening, - prod_x_1_opening, - prod_1_0_opening, - ], - witness_perm_check_opening, - witness_perm_check_eval, - perm_oracle_opening: s_perm_opening, - perm_oracle_eval: s_perm_eval, - // ======================================================================= - // PCS components: zero check - // ======================================================================= - witness_zero_check_openings, - witness_zero_check_evals, - selector_oracle_openings, - selector_oracle_evals, - // ======================================================================= - // PCS components: public inputs - // ======================================================================= - pi_eval, - pi_opening, - // ======================================================================= - // IOP components - // ======================================================================= - zero_check_proof, - perm_check_proof, - }) - } - - /// Verify the HyperPlonk proof. - /// - /// Inputs: - /// - `vk`: verification key - /// - `pub_input`: online public input - /// - `proof`: HyperPlonk SNARK proof - /// Outputs: - /// - Return a boolean on whether the verification is successful - /// - /// 1. Verify zero_check_proof on - /// - /// `f(q_0(x),...q_l(x), w_0(x),...w_d(x))` - /// - /// where `f` is the constraint polynomial i.e., - /// ```ignore - /// f(q_l, q_r, q_m, q_o, w_a, w_b, w_c) - /// = q_l w_a(x) + q_r w_b(x) + q_m w_a(x)w_b(x) - q_o w_c(x) - /// ``` - /// in vanilla plonk, and obtain a ZeroCheckSubClaim - /// - /// 2. Verify perm_check_proof on `\{w_i(x)\}` and `permutation_oracle` - /// - /// 3. check subclaim validity - /// - /// 4. Verify the opening against the commitment: - /// - check permutation check evaluations - /// - check zero check evaluations - /// - public input consistency checks - fn verify( - vk: impl Borrow, - pub_input: &[E::Fr], - proof: &Self::Proof, - ) -> Result { - let vk = vk.borrow(); - let start = start_timer!(|| "hyperplonk verification"); - - let mut transcript = IOPTranscript::::new(b"hyperplonk"); - // witness assignment of length 2^n - let num_vars = vk.params.nv; - let log_num_witness_polys = vk.params.log_n_wires; - // number of variables in merged polynomial for Multilinear-KZG - let merged_nv = num_vars + log_num_witness_polys; - - // online public input of length 2^\ell - let ell = vk.params.log_pub_input_len; - - let pi_poly = DenseMultilinearExtension::from_evaluations_slice(ell as usize, pub_input); - - // ======================================================================= - // 0. sanity checks - // ======================================================================= - // public input length - if pub_input.len() != 1 << ell { - return Err(HyperPlonkErrors::InvalidProver(format!( - "Public input length is not correct: got {}, expect {}", - pub_input.len(), - 1 << ell - ))); - } - if proof.selector_oracle_evals.len() != 1 << vk.params.log_n_selectors { - return Err(HyperPlonkErrors::InvalidProver(format!( - "Selector length is not correct: got {}, expect {}", - proof.selector_oracle_evals.len(), - 1 << vk.params.log_n_selectors - ))); - } - if proof.witness_zero_check_evals.len() != 1 << log_num_witness_polys { - return Err(HyperPlonkErrors::InvalidProver(format!( - "Witness length is not correct: got {}, expect {}", - proof.witness_zero_check_evals.len(), - 1 << log_num_witness_polys - ))); - } - if proof.prod_openings.len() != 5 { - return Err(HyperPlonkErrors::InvalidProver(format!( - "the number of product polynomial evaluations is not correct: got {}, expect {}", - proof.prod_openings.len(), - 5 - ))); - } - - // ======================================================================= - // 1. Verify zero_check_proof on - // `f(q_0(x),...q_l(x), w_0(x),...w_d(x))` - // - // where `f` is the constraint polynomial i.e., - // - // f(q_l, q_r, q_m, q_o, w_a, w_b, w_c) - // = q_l w_a(x) + q_r w_b(x) + q_m w_a(x)w_b(x) - q_o w_c(x) - // - // ======================================================================= - let step = start_timer!(|| "verify zero check"); - // Zero check and perm check have different AuxInfo - let zero_check_aux_info = VPAuxInfo:: { - max_degree: vk.params.gate_func.degree(), - num_variables: num_vars, - phantom: PhantomData::default(), - }; - - // push witness to transcript - transcript.append_serializable_element(b"w", &proof.w_merged_com)?; - - let zero_check_sub_claim = >::verify( - &proof.zero_check_proof, - &zero_check_aux_info, - &mut transcript, - )?; - - let zero_check_point = &zero_check_sub_claim.point; - - // check zero check subclaim - let f_eval = eval_f( - &vk.params.gate_func, - &proof.selector_oracle_evals, - &proof.witness_zero_check_evals, - )?; - if f_eval != zero_check_sub_claim.expected_evaluation { - return Err(HyperPlonkErrors::InvalidProof( - "zero check evaluation failed".to_string(), - )); - } - - end_timer!(step); - // ======================================================================= - // 2. Verify perm_check_proof on `\{w_i(x)\}` and `permutation_oracle` - // ======================================================================= - let step = start_timer!(|| "verify permutation check"); - - // Zero check and perm check have different AuxInfo - let perm_check_aux_info = VPAuxInfo:: { - // Prod(x) has a max degree of 2 - max_degree: 2, - // degree of merged poly - num_variables: merged_nv, - phantom: PhantomData::default(), - }; - let perm_check_sub_claim = >::verify( - &proof.perm_check_proof, - &perm_check_aux_info, - &mut transcript, - )?; - - let perm_check_point = &perm_check_sub_claim - .product_check_sub_claim - .zero_check_sub_claim - .point; - - let alpha = perm_check_sub_claim.product_check_sub_claim.alpha; - let (beta, gamma) = perm_check_sub_claim.challenges; - - // check perm check subclaim: - // proof.witness_perm_check_eval ?= perm_check_sub_claim.expected_eval - // - // Q(x) := prod(1,x) - prod(x, 0) * prod(x, 1) - // + alpha * ( - // (g(x) + beta * s_perm(x) + gamma) * prod(0, x) - // - (f(x) + beta * s_id(x) + gamma)) - // where - // - Q(x) is perm_check_sub_claim.zero_check.exp_eval - // - prod(1, x) ... from prod(x) evaluated over (1, zero_point) - // - g(x), f(x) are both w_merged over (zero_point) - // - s_perm(x) and s_id(x) from vk_param.perm_oracle - // - alpha, beta, gamma from challenge - - let s_id = identity_permutation_mle::(perm_check_point.len()); - let s_id_eval = s_id.evaluate(perm_check_point).ok_or_else(|| { - HyperPlonkErrors::InvalidVerifier("unable to evaluate s_id(x)".to_string()) - })?; - - let q_x_rec = proof.prod_evals[1] - proof.prod_evals[2] * proof.prod_evals[3] - + alpha - * ((proof.witness_perm_check_eval + beta * proof.perm_oracle_eval + gamma) - * proof.prod_evals[0] - - (proof.witness_perm_check_eval + beta * s_id_eval + gamma)); - - if q_x_rec - != perm_check_sub_claim - .product_check_sub_claim - .zero_check_sub_claim - .expected_evaluation - { - return Err(HyperPlonkErrors::InvalidVerifier( - "evaluation failed".to_string(), - )); - } - - end_timer!(step); - // ======================================================================= - // 3. Verify the opening against the commitment - // ======================================================================= - let step = start_timer!(|| "verify commitments"); - - // ======================================================================= - // 3.1 check permutation check evaluations - // ======================================================================= - // witness for permutation check - if !PCS::verify( - &vk.pcs_param, - &proof.w_merged_com, - perm_check_point, - &proof.witness_perm_check_eval, - &proof.witness_perm_check_opening, - )? { - return Err(HyperPlonkErrors::InvalidProof( - "witness for permutation check pcs verification failed".to_string(), - )); - } - - if !PCS::verify( - &vk.pcs_param, - &vk.perm_com, - perm_check_point, - &proof.perm_oracle_eval, - &proof.perm_oracle_opening, - )? { - return Err(HyperPlonkErrors::InvalidProof( - "perm oracle pcs verification failed".to_string(), - )); - } - - // prod(x) for permutation check - // TODO: batch verification - - // prod(0, x) - if !PCS::verify( - &vk.pcs_param, - &proof.perm_check_proof.prod_x_comm, - &[perm_check_point.as_slice(), &[E::Fr::zero()]].concat(), - &proof.prod_evals[0], - &proof.prod_openings[0], - )? { - return Err(HyperPlonkErrors::InvalidProof( - "prod(0, x) pcs verification failed".to_string(), - )); - } - // prod(1, x) - if !PCS::verify( - &vk.pcs_param, - &proof.perm_check_proof.prod_x_comm, - &[perm_check_point.as_slice(), &[E::Fr::one()]].concat(), - &proof.prod_evals[1], - &proof.prod_openings[1], - )? { - return Err(HyperPlonkErrors::InvalidProof( - "prod(1, x) pcs verification failed".to_string(), - )); - } - // prod(x, 0) - if !PCS::verify( - &vk.pcs_param, - &proof.perm_check_proof.prod_x_comm, - &[&[E::Fr::zero()], perm_check_point.as_slice()].concat(), - &proof.prod_evals[2], - &proof.prod_openings[2], - )? { - return Err(HyperPlonkErrors::InvalidProof( - "prod(x, 0) pcs verification failed".to_string(), - )); - } - // prod(x, 1) - if !PCS::verify( - &vk.pcs_param, - &proof.perm_check_proof.prod_x_comm, - &[&[E::Fr::one()], perm_check_point.as_slice()].concat(), - &proof.prod_evals[3], - &proof.prod_openings[3], - )? { - return Err(HyperPlonkErrors::InvalidProof( - "prod(x, 1) pcs verification failed".to_string(), - )); - } - // prod(1, ..., 1, 0) = 1 - let prod_final_query = perm_check_sub_claim.product_check_sub_claim.final_query; - if !PCS::verify( - &vk.pcs_param, - &proof.perm_check_proof.prod_x_comm, - &prod_final_query.0, - &prod_final_query.1, - &proof.prod_openings[4], - )? { - return Err(HyperPlonkErrors::InvalidProof( - "prod(1, ..., 1, 0) pcs verification failed".to_string(), - )); - } - - // ======================================================================= - // 3.2 check zero check evaluations - // ======================================================================= - // witness for zero check - // TODO: batch verification - for (i, (opening, eval)) in proof - .witness_zero_check_openings - .iter() - .zip(proof.witness_zero_check_evals.iter()) - .enumerate() - { - let tmp_point = gen_eval_point(i, log_num_witness_polys, zero_check_point); - if !PCS::verify( - &vk.pcs_param, - &proof.w_merged_com, - &tmp_point, - eval, - opening, - )? { - return Err(HyperPlonkErrors::InvalidProof( - "witness for zero_check pcs verification failed".to_string(), - )); - } - } - - // selector for zero check - for (commitment, (opening, eval)) in vk.selector_com.iter().zip( - proof - .selector_oracle_openings - .iter() - .zip(proof.selector_oracle_evals.iter()), - ) { - if !PCS::verify(&vk.pcs_param, commitment, perm_check_point, eval, opening)? { - return Err(HyperPlonkErrors::InvalidProof( - "selector pcs verification failed".to_string(), - )); - } - } - - // ======================================================================= - // 3.3 public input consistency checks - // ======================================================================= - let mut r_pi = transcript.get_and_append_challenge_vectors(b"r_pi", ell)?; - let pi_eval = pi_poly.evaluate(&r_pi).ok_or_else(|| { - HyperPlonkErrors::InvalidParameters("evaluation dimension does not match".to_string()) - })?; - r_pi = [ - vec![E::Fr::zero(); num_vars - ell], - r_pi, - vec![E::Fr::zero(); log_num_witness_polys], - ] - .concat(); - if !PCS::verify( - &vk.pcs_param, - &proof.w_merged_com, - &r_pi, - &pi_eval, - &proof.pi_opening, - )? { - return Err(HyperPlonkErrors::InvalidProof( - "public input pcs verification failed".to_string(), - )); - } - - end_timer!(step); - end_timer!(start); - Ok(true) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - selectors::SelectorColumn, - structs::{CustomizedGates, HyperPlonkParams}, - witness::WitnessColumn, - }; - use ark_bls12_381::Bls12_381; - use ark_std::test_rng; - use jf_primitives::pcs::prelude::MultilinearKzgPCS; - use poly_iop::prelude::random_permutation_mle; - - #[test] - fn test_hyperplonk_e2e() -> Result<(), HyperPlonkErrors> { - // Example: - // q_L(X) * W_1(X)^5 - W_2(X) = 0 - // is represented as - // vec![ - // ( 1, Some(id_qL), vec![id_W1, id_W1, id_W1, id_W1, id_W1]), - // (-1, None, vec![id_W2]) - // ] - // - // 4 public input - // 1 selector, - // 2 witnesses, - // 2 variables for MLE, - // 4 wires, - let gates = CustomizedGates { - gates: vec![(1, Some(0), vec![0, 0, 0, 0, 0]), (-1, None, vec![1])], - }; - test_hyperplonk_helper::(2, 2, 0, 1, gates) - } - - fn test_hyperplonk_helper( - nv: usize, - log_pub_input_len: usize, - log_n_selectors: usize, - log_n_wires: usize, - gate_func: CustomizedGates, - ) -> Result<(), HyperPlonkErrors> { - let mut rng = test_rng(); - let pcs_srs = MultilinearKzgPCS::::gen_srs_for_testing(&mut rng, 15)?; - let merged_nv = nv + log_n_wires; - - // generate index - let params = HyperPlonkParams { - nv, - log_pub_input_len, - log_n_selectors, - log_n_wires, - gate_func, - }; - let permutation = identity_permutation_mle(merged_nv).evaluations.clone(); - let q1 = SelectorColumn(vec![E::Fr::one(), E::Fr::one(), E::Fr::one(), E::Fr::one()]); - let index = HyperPlonkIndex { - params, - permutation, - selectors: vec![q1], - }; - - // generate pk and vks - let (pk, vk) = as HyperPlonkSNARK>>::preprocess( - &index, &pcs_srs, - )?; - - // w1 := [0, 1, 2, 3] - let w1 = WitnessColumn(vec![ - E::Fr::zero(), - E::Fr::one(), - E::Fr::from(2u64), - E::Fr::from(3u64), - ]); - // w2 := [0^5, 1^5, 2^5, 3^5] - let w2 = WitnessColumn(vec![ - E::Fr::zero(), - E::Fr::one(), - E::Fr::from(32u64), - E::Fr::from(243u64), - ]); - // public input = w1 - let pi = w1.clone(); - - // generate a proof and verify - let proof = as HyperPlonkSNARK>>::prove( - &pk, - &pi.0, - &[w1.clone(), w2.clone()], - )?; - - let _verify = as HyperPlonkSNARK>>::verify( - &vk, &pi.0, &proof, - )?; - - // bad path 1: wrong permutation - let rand_perm: Vec = random_permutation_mle(merged_nv, &mut rng) - .evaluations - .clone(); - let mut bad_index = index; - bad_index.permutation = rand_perm; - // generate pk and vks - let (_, bad_vk) = as HyperPlonkSNARK>>::preprocess( - &bad_index, &pcs_srs, - )?; - assert!( - as HyperPlonkSNARK>>::verify( - &bad_vk, &pi.0, &proof, - ) - .is_err() - ); - - // bad path 2: wrong witness - let mut w1_bad = w1; - w1_bad.0[0] = E::Fr::one(); - assert!( - as HyperPlonkSNARK>>::prove( - &pk, - &pi.0, - &[w1_bad, w2], - ) - .is_err() - ); - - Ok(()) - } -} diff --git a/hyperplonk/src/mock.rs b/hyperplonk/src/mock.rs new file mode 100644 index 0000000..7539963 --- /dev/null +++ b/hyperplonk/src/mock.rs @@ -0,0 +1,235 @@ +use arithmetic::identity_permutation_mle; +use ark_ff::PrimeField; +use ark_poly::MultilinearExtension; +use ark_std::{log2, test_rng}; + +use crate::{ + custom_gate::CustomizedGates, + selectors::SelectorColumn, + structs::{HyperPlonkIndex, HyperPlonkParams}, + witness::WitnessColumn, +}; + +pub struct MockCircuit { + pub witnesses: Vec>, + pub index: HyperPlonkIndex, +} + +impl MockCircuit { + /// Number of variables in a multilinear system + pub fn num_variables(&self) -> usize { + self.index.num_variables() + } + + /// number of selector columns + pub fn num_selector_columns(&self) -> usize { + self.index.num_selector_columns() + } + + /// number of witness columns + pub fn num_witness_columns(&self) -> usize { + self.index.num_witness_columns() + } +} + +impl MockCircuit { + /// Generate a mock plonk circuit for the input constraint size. + pub fn new(num_constraints: usize, gate: &CustomizedGates) -> MockCircuit { + let mut rng = test_rng(); + let nv = log2(num_constraints); + let num_selectors = gate.num_selector_columns(); + let num_witnesses = gate.num_witness_columns(); + let log_n_wires = log2(num_witnesses); + let merged_nv = nv + log_n_wires; + + let mut selectors: Vec> = vec![SelectorColumn::default(); num_selectors]; + let mut witnesses: Vec> = vec![WitnessColumn::default(); num_witnesses]; + + for _cs_counter in 0..num_constraints { + let mut cur_selectors: Vec = (0..(num_selectors - 1)) + .map(|_| F::rand(&mut rng)) + .collect(); + let cur_witness: Vec = (0..num_witnesses).map(|_| F::rand(&mut rng)).collect(); + let mut last_selector = F::zero(); + for (index, (coeff, q, wit)) in gate.gates.iter().enumerate() { + if index != num_selectors - 1 { + let mut cur_monomial = if *coeff < 0 { + -F::from((-coeff) as u64) + } else { + F::from(*coeff as u64) + }; + cur_monomial = match q { + Some(p) => cur_monomial * cur_selectors[*p], + None => cur_monomial, + }; + for wit_index in wit.iter() { + cur_monomial *= cur_witness[*wit_index]; + } + last_selector += cur_monomial; + } else { + let mut cur_monomial = if *coeff < 0 { + -F::from((-coeff) as u64) + } else { + F::from(*coeff as u64) + }; + for wit_index in wit.iter() { + cur_monomial *= cur_witness[*wit_index]; + } + last_selector /= -cur_monomial; + } + } + cur_selectors.push(last_selector); + for i in 0..num_selectors { + selectors[i].append(cur_selectors[i]); + } + for i in 0..num_witnesses { + witnesses[i].append(cur_witness[i]); + } + } + + let params = HyperPlonkParams { + num_constraints, + num_pub_input: num_constraints, + gate_func: gate.clone(), + }; + + let permutation = identity_permutation_mle(merged_nv as usize).to_evaluations(); + let index = HyperPlonkIndex { + params, + permutation, + selectors, + }; + + Self { witnesses, index } + } + + pub fn is_satisfied(&self) -> bool { + for current_row in 0..self.num_variables() { + let mut cur = F::zero(); + for (coeff, q, wit) in self.index.params.gate_func.gates.iter() { + let mut cur_monomial = if *coeff < 0 { + -F::from((-coeff) as u64) + } else { + F::from(*coeff as u64) + }; + cur_monomial = match q { + Some(p) => cur_monomial * self.index.selectors[*p].0[current_row], + None => cur_monomial, + }; + for wit_index in wit.iter() { + cur_monomial *= self.witnesses[*wit_index].0[current_row]; + } + cur += cur_monomial; + } + if !cur.is_zero() { + return false; + } + } + + true + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{errors::HyperPlonkErrors, HyperPlonkSNARK}; + use ark_bls12_381::{Bls12_381, Fr}; + use pcs::{ + prelude::{MultilinearKzgPCS, MultilinearUniversalParams, UnivariateUniversalParams}, + PolynomialCommitmentScheme, + }; + use poly_iop::PolyIOP; + + #[test] + fn test_mock_circuit_sat() { + for i in 1..10 { + let vanilla_gate = CustomizedGates::vanilla_plonk_gate(); + let circuit = MockCircuit::::new(1 << i, &vanilla_gate); + assert!(circuit.is_satisfied()); + + let jf_gate = CustomizedGates::jellyfish_turbo_plonk_gate(); + let circuit = MockCircuit::::new(1 << i, &jf_gate); + assert!(circuit.is_satisfied()); + + for num_witness in 2..10 { + for degree in 1..10 { + let mock_gate = CustomizedGates::mock_gate(num_witness, degree); + let circuit = MockCircuit::::new(1 << i, &mock_gate); + assert!(circuit.is_satisfied()); + } + } + } + } + + fn test_mock_circuit_zkp_helper( + nv: usize, + gate: &CustomizedGates, + pcs_srs: &( + MultilinearUniversalParams, + UnivariateUniversalParams, + ), + ) -> Result<(), HyperPlonkErrors> { + let circuit = MockCircuit::::new(1 << nv, gate); + assert!(circuit.is_satisfied()); + + let index = circuit.index; + + // generate pk and vks + let (pk, vk) = + as HyperPlonkSNARK>>::preprocess( + &index, &pcs_srs, + )?; + // generate a proof and verify + let proof = + as HyperPlonkSNARK>>::prove( + &pk, + &circuit.witnesses[0].0, + &circuit.witnesses, + )?; + + let verify = + as HyperPlonkSNARK>>::verify( + &vk, + &circuit.witnesses[0].0, + &proof, + )?; + assert!(verify); + Ok(()) + } + + #[test] + fn test_mock_circuit_zkp() -> Result<(), HyperPlonkErrors> { + let mut rng = test_rng(); + let pcs_srs = MultilinearKzgPCS::::gen_srs_for_testing(&mut rng, 16)?; + for nv in 1..10 { + let vanilla_gate = CustomizedGates::vanilla_plonk_gate(); + test_mock_circuit_zkp_helper(nv, &vanilla_gate, &pcs_srs)?; + } + for nv in 1..10 { + let tubro_gate = CustomizedGates::jellyfish_turbo_plonk_gate(); + test_mock_circuit_zkp_helper(nv, &tubro_gate, &pcs_srs)?; + } + let nv = 5; + for num_witness in 2..10 { + for degree in [1, 2, 4, 8, 16] { + let mock_gate = CustomizedGates::mock_gate(num_witness, degree); + test_mock_circuit_zkp_helper(nv, &mock_gate, &pcs_srs)?; + } + } + + Ok(()) + } + + #[test] + fn test_mock_circuit_e2e() -> Result<(), HyperPlonkErrors> { + let mut rng = test_rng(); + let pcs_srs = MultilinearKzgPCS::::gen_srs_for_testing(&mut rng, 23)?; + let nv = 18; + + let vanilla_gate = CustomizedGates::vanilla_plonk_gate(); + test_mock_circuit_zkp_helper(nv, &vanilla_gate, &pcs_srs)?; + + Ok(()) + } +} diff --git a/hyperplonk/src/prelude.rs b/hyperplonk/src/prelude.rs new file mode 100644 index 0000000..254d29b --- /dev/null +++ b/hyperplonk/src/prelude.rs @@ -0,0 +1,4 @@ +pub use crate::{ + custom_gate::CustomizedGates, errors::HyperPlonkErrors, mock::MockCircuit, + selectors::SelectorColumn, witness::WitnessColumn, HyperPlonkSNARK, +}; diff --git a/hyperplonk/src/selectors.rs b/hyperplonk/src/selectors.rs index a08f4c5..eb6c68b 100644 --- a/hyperplonk/src/selectors.rs +++ b/hyperplonk/src/selectors.rs @@ -9,7 +9,7 @@ use std::rc::Rc; pub struct SelectorRow(pub(crate) Vec); /// A column of selectors of length `#constraints` -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct SelectorColumn(pub(crate) Vec); impl SelectorColumn { @@ -19,6 +19,11 @@ impl SelectorColumn { log2(self.0.len()) as usize } + /// Append a new element to the selector column + pub fn append(&mut self, new_element: F) { + self.0.push(new_element) + } + /// Build selector columns from rows pub fn from_selector_rows( selector_rows: &[SelectorRow], diff --git a/hyperplonk/src/snark.rs b/hyperplonk/src/snark.rs new file mode 100644 index 0000000..b0283be --- /dev/null +++ b/hyperplonk/src/snark.rs @@ -0,0 +1,829 @@ +use crate::{ + errors::HyperPlonkErrors, + structs::{HyperPlonkIndex, HyperPlonkProof, HyperPlonkProvingKey, HyperPlonkVerifyingKey}, + utils::{build_f, eval_f, prover_sanity_check, PcsAccumulator}, + witness::WitnessColumn, + HyperPlonkSNARK, +}; +use arithmetic::{ + evaluate_opt, gen_eval_point, identity_permutation_mle, merge_polynomials, VPAuxInfo, +}; +use ark_ec::PairingEngine; +use ark_poly::DenseMultilinearExtension; +use ark_std::{end_timer, log2, start_timer, One, Zero}; +use pcs::prelude::{compute_qx_degree, PolynomialCommitmentScheme}; +use poly_iop::{ + prelude::{PermutationCheck, ZeroCheck}, + PolyIOP, +}; +use std::{cmp::max, marker::PhantomData, rc::Rc}; +use transcript::IOPTranscript; + +impl HyperPlonkSNARK for PolyIOP +where + E: PairingEngine, + // Ideally we want to access polynomial as PCS::Polynomial, instead of instantiating it here. + // But since PCS::Polynomial can be both univariate or multivariate in our implementation + // we cannot bound PCS::Polynomial with a property trait bound. + PCS: PolynomialCommitmentScheme< + E, + Polynomial = Rc>, + Point = Vec, + Evaluation = E::Fr, + >, +{ + type Index = HyperPlonkIndex; + type ProvingKey = HyperPlonkProvingKey; + type VerifyingKey = HyperPlonkVerifyingKey; + type Proof = HyperPlonkProof; + + fn preprocess( + index: &Self::Index, + pcs_srs: &PCS::SRS, + ) -> Result<(Self::ProvingKey, Self::VerifyingKey), HyperPlonkErrors> { + let num_vars = index.num_variables(); + + let log_num_witness_polys = log2(index.num_witness_columns()) as usize; + let log_num_selector_polys = log2(index.num_selector_columns()) as usize; + + let witness_merged_nv = num_vars + log_num_witness_polys; + let selector_merged_nv = num_vars + log_num_selector_polys; + + let max_nv = max(witness_merged_nv + 1, selector_merged_nv); + let max_points = max( + // prod(x) has 5 points + 5, + max( + // selector points + index.num_selector_columns(), + // witness points + public input point + perm point + index.num_witness_columns() + 2, + ), + ); + + let supported_uni_degree = compute_qx_degree(max_nv, max_points); + let supported_ml_degree = max_nv; + + // extract PCS prover and verifier keys from SRS + let (pcs_prover_param, pcs_verifier_param) = + PCS::trim(pcs_srs, supported_uni_degree, Some(supported_ml_degree))?; + + // build permutation oracles + let permutation_oracle = Rc::new(DenseMultilinearExtension::from_evaluations_slice( + witness_merged_nv, + &index.permutation, + )); + let perm_com = PCS::commit(&pcs_prover_param, &permutation_oracle)?; + + // build selector oracles and commit to it + let selector_oracles: Vec>> = index + .selectors + .iter() + .map(|s| Rc::new(DenseMultilinearExtension::from(s))) + .collect(); + + let selector_merged = merge_polynomials(&selector_oracles)?; + let selector_com = PCS::commit(&pcs_prover_param, &selector_merged)?; + + Ok(( + Self::ProvingKey { + params: index.params.clone(), + permutation_oracle: permutation_oracle.clone(), + selector_oracles, + selector_com: selector_com.clone(), + pcs_param: pcs_prover_param, + }, + Self::VerifyingKey { + params: index.params.clone(), + permutation_oracle, + pcs_param: pcs_verifier_param, + selector_com, + perm_com, + }, + )) + } + + /// Generate HyperPlonk SNARK proof. + /// + /// Inputs: + /// - `pk`: circuit proving key + /// - `pub_input`: online public input of length 2^\ell + /// - `witness`: witness assignment of length 2^n + /// Outputs: + /// - The HyperPlonk SNARK proof. + /// + /// Steps: + /// + /// 1. Commit Witness polynomials `w_i(x)` and append commitment to + /// transcript + /// + /// 2. Run ZeroCheck on + /// + /// `f(q_0(x),...q_l(x), w_0(x),...w_d(x))` + /// + /// where `f` is the constraint polynomial i.e., + /// ```ignore + /// f(q_l, q_r, q_m, q_o, w_a, w_b, w_c) + /// = q_l w_a(x) + q_r w_b(x) + q_m w_a(x)w_b(x) - q_o w_c(x) + /// ``` + /// in vanilla plonk, and obtain a ZeroCheckSubClaim + /// + /// 3. Run permutation check on `\{w_i(x)\}` and `permutation_oracle`, and + /// obtain a PermCheckSubClaim. + /// + /// 4. Generate evaluations and corresponding proofs + /// - 4.1. (deferred) batch opening prod(x) at + /// - [0, perm_check_point] + /// - [1, perm_check_point] + /// - [perm_check_point, 0] + /// - [perm_check_point, 1] + /// - [1,...1, 0] + /// + /// - 4.2. permutation check evaluations and proofs + /// - 4.2.1. (deferred) wi_poly(perm_check_point) + /// + /// - 4.3. zero check evaluations and proofs + /// - 4.3.1. (deferred) wi_poly(zero_check_point) + /// - 4.3.2. (deferred) selector_poly(zero_check_point) + /// + /// - 4.4. public input consistency checks + /// - pi_poly(r_pi) where r_pi is sampled from transcript + /// + /// - 5. deferred batch opening + fn prove( + pk: &Self::ProvingKey, + pub_input: &[E::Fr], + witnesses: &[WitnessColumn], + ) -> Result { + let start = start_timer!(|| "hyperplonk proving"); + let mut transcript = IOPTranscript::::new(b"hyperplonk"); + + prover_sanity_check(&pk.params, pub_input, witnesses)?; + + // witness assignment of length 2^n + let num_vars = pk.params.num_variables(); + let log_num_witness_polys = log2(pk.params.num_witness_columns()) as usize; + let log_num_selector_polys = log2(pk.params.num_selector_columns()) as usize; + // number of variables in merged polynomial for Multilinear-KZG + let merged_nv = num_vars + log_num_witness_polys; + // online public input of length 2^\ell + let ell = log2(pk.params.num_pub_input) as usize; + + // We use accumulators to store the polynomials and their eval points. + // They are batch opened at a later stage. + // This includes + // - witnesses + // - prod(x) + // - selectors + // + // Accumulator for w_merged and its points + let mut w_merged_pcs_acc = PcsAccumulator::::new(); + // Accumulator for prod(x) and its points + let mut prod_pcs_acc = PcsAccumulator::::new(); + // Accumulator for prod(x) and its points + let mut selector_pcs_acc = PcsAccumulator::::new(); + + let witness_polys: Vec>> = witnesses + .iter() + .map(|w| Rc::new(DenseMultilinearExtension::from(w))) + .collect(); + + // ======================================================================= + // 1. Commit Witness polynomials `w_i(x)` and append commitment to + // transcript + // ======================================================================= + let step = start_timer!(|| "commit witnesses"); + let w_merged = merge_polynomials(&witness_polys)?; + if w_merged.num_vars != merged_nv { + return Err(HyperPlonkErrors::InvalidParameters(format!( + "merged witness poly has a different num_vars ({}) from expected ({})", + w_merged.num_vars, merged_nv + ))); + } + let w_merged_com = PCS::commit(&pk.pcs_param, &w_merged)?; + w_merged_pcs_acc.init_poly(w_merged.clone(), w_merged_com.clone())?; + transcript.append_serializable_element(b"w", &w_merged_com)?; + end_timer!(step); + // ======================================================================= + // 2 Run ZeroCheck on + // + // `f(q_0(x),...q_l(x), w_0(x),...w_d(x))` + // + // where `f` is the constraint polynomial i.e., + // + // f(q_l, q_r, q_m, q_o, w_a, w_b, w_c) + // = q_l w_a(x) + q_r w_b(x) + q_m w_a(x)w_b(x) - q_o w_c(x) + // + // in vanilla plonk, and obtain a ZeroCheckSubClaim + // ======================================================================= + let step = start_timer!(|| "ZeroCheck on f"); + + let fx = build_f( + &pk.params.gate_func, + pk.params.num_variables(), + &pk.selector_oracles, + &witness_polys, + )?; + + let zero_check_proof = >::prove(&fx, &mut transcript)?; + end_timer!(step); + // ======================================================================= + // 3. Run permutation check on `\{w_i(x)\}` and `permutation_oracle`, and + // obtain a PermCheckSubClaim. + // ======================================================================= + let step = start_timer!(|| "Permutation check on w_i(x)"); + + let (perm_check_proof, prod_x) = >::prove( + &pk.pcs_param, + &w_merged, + &w_merged, + &pk.permutation_oracle, + &mut transcript, + )?; + let perm_check_point = &perm_check_proof.zero_check_proof.point; + + end_timer!(step); + // ======================================================================= + // 4. Generate evaluations and corresponding proofs + // - 4.1. (deferred) batch opening prod(x) at + // - [0, perm_check_point] + // - [1, perm_check_point] + // - [perm_check_point, 0] + // - [perm_check_point, 1] + // - [1,...1, 0] + // + // - 4.2. permutation check evaluations and proofs + // - 4.2.1. (deferred) wi_poly(perm_check_point) + // + // - 4.3. zero check evaluations and proofs + // - 4.3.1. (deferred) wi_poly(zero_check_point) + // - 4.3.2. (deferred) selector_poly(zero_check_point) + // + // - 4.4. (deferred) public input consistency checks + // - pi_poly(r_pi) where r_pi is sampled from transcript + // ======================================================================= + let step = start_timer!(|| "opening and evaluations"); + + // 4.1 (deferred) open prod(0,x), prod(1, x), prod(x, 0), prod(x, 1) + // perm_check_point + prod_pcs_acc.init_poly(prod_x, perm_check_proof.prod_x_comm.clone())?; + // prod(0, x) + let tmp_point1 = [perm_check_point.as_slice(), &[E::Fr::zero()]].concat(); + // prod(1, x) + let tmp_point2 = [perm_check_point.as_slice(), &[E::Fr::one()]].concat(); + // prod(x, 0) + let tmp_point3 = [&[E::Fr::zero()], perm_check_point.as_slice()].concat(); + // prod(x, 1) + let tmp_point4 = [&[E::Fr::one()], perm_check_point.as_slice()].concat(); + // prod(1, ..., 1, 0) + let tmp_point5 = [vec![E::Fr::zero()], vec![E::Fr::one(); merged_nv]].concat(); + + prod_pcs_acc.insert_point(&tmp_point1); + prod_pcs_acc.insert_point(&tmp_point2); + prod_pcs_acc.insert_point(&tmp_point3); + prod_pcs_acc.insert_point(&tmp_point4); + prod_pcs_acc.insert_point(&tmp_point5); + + // 4.2 permutation check + // - 4.2.1. (deferred) wi_poly(perm_check_point) + w_merged_pcs_acc.insert_point(perm_check_point); + + #[cfg(feature = "extensive_sanity_checks")] + { + // sanity check + let eval = pk + .permutation_oracle + .evaluate(&perm_check_proof.zero_check_proof.point) + .ok_or_else(|| { + HyperPlonkErrors::InvalidParameters( + "perm_oracle evaluation dimension does not match".to_string(), + ) + })?; + if eval != perm_oracle_eval { + return Err(HyperPlonkErrors::InvalidProver( + "perm_oracle evaluation is different from PCS opening".to_string(), + )); + } + } + + // - 4.3. zero check evaluations and proofs + // - 4.3.1 (deferred) wi_poly(zero_check_point) + for i in 0..witness_polys.len() { + let tmp_point = gen_eval_point(i, log_num_witness_polys, &zero_check_proof.point); + // Deferred opening zero check proof + w_merged_pcs_acc.insert_point(&tmp_point); + } + + // - 4.3.2. (deferred) selector_poly(zero_check_point) + let selector_merged = merge_polynomials(&pk.selector_oracles)?; + selector_pcs_acc.init_poly(selector_merged, pk.selector_com.clone())?; + for i in 0..pk.selector_oracles.len() { + let tmp_point = gen_eval_point(i, log_num_selector_polys, &zero_check_proof.point); + // Deferred opening zero check proof + selector_pcs_acc.insert_point(&tmp_point); + } + + // - 4.4. public input consistency checks + // - pi_poly(r_pi) where r_pi is sampled from transcript + let r_pi = transcript.get_and_append_challenge_vectors(b"r_pi", ell)?; + let tmp_point = [ + vec![E::Fr::zero(); num_vars - ell], + r_pi, + vec![E::Fr::zero(); log_num_witness_polys], + ] + .concat(); + w_merged_pcs_acc.insert_point(&tmp_point); + + #[cfg(feature = "extensive_sanity_checks")] + { + // sanity check + let pi_poly = Rc::new(DenseMultilinearExtension::from_evaluations_slice( + ell, pub_input, + )); + + let eval = pi_poly.evaluate(&r_pi).ok_or_else(|| { + HyperPlonkErrors::InvalidParameters( + "public input evaluation dimension does not match".to_string(), + ) + })?; + if eval != pi_eval { + return Err(HyperPlonkErrors::InvalidProver( + "public input evaluation is different from PCS opening".to_string(), + )); + } + } + end_timer!(step); + + // ======================================================================= + // 5. deferred batch opening + // ======================================================================= + let step = start_timer!(|| "deferred batch openings"); + let sub_step = start_timer!(|| "open witness"); + let (w_merged_batch_opening, w_merged_batch_evals) = + w_merged_pcs_acc.batch_open(&pk.pcs_param)?; + end_timer!(sub_step); + + let sub_step = start_timer!(|| "open prod(x)"); + let (prod_batch_openings, prod_batch_evals) = prod_pcs_acc.batch_open(&pk.pcs_param)?; + end_timer!(sub_step); + + let sub_step = start_timer!(|| "open selector"); + let (selector_batch_opening, selector_batch_evals) = + selector_pcs_acc.batch_open(&pk.pcs_param)?; + end_timer!(sub_step); + end_timer!(step); + end_timer!(start); + + Ok(HyperPlonkProof { + // ======================================================================= + // witness related + // ======================================================================= + /// PCS commit for witnesses + w_merged_com, + // Batch opening for witness commitment + // - PermCheck eval: 1 point + // - ZeroCheck evals: #witness points + // - public input eval: 1 point + w_merged_batch_opening, + // Evaluations of Witness + // - PermCheck eval: 1 point + // - ZeroCheck evals: #witness points + // - public input eval: 1 point + w_merged_batch_evals, + // ======================================================================= + // prod(x) related + // ======================================================================= + // prod(x)'s openings + // - prod(0, x), + // - prod(1, x), + // - prod(x, 0), + // - prod(x, 1), + // - prod(1, ..., 1,0) + prod_batch_openings, + // prod(x)'s evaluations + // - prod(0, x), + // - prod(1, x), + // - prod(x, 0), + // - prod(x, 1), + // - prod(1, ..., 1,0) + prod_batch_evals, + // ======================================================================= + // selectors related + // ======================================================================= + // PCS openings for selectors on zero check point + selector_batch_opening, + // Evaluates of selectors on zero check point + selector_batch_evals, + // ======================================================================= + // IOP proofs + // ======================================================================= + // the custom gate zerocheck proof + zero_check_proof, + // the permutation check proof for copy constraints + perm_check_proof, + }) + } + + /// Verify the HyperPlonk proof. + /// + /// Inputs: + /// - `vk`: verification key + /// - `pub_input`: online public input + /// - `proof`: HyperPlonk SNARK proof + /// Outputs: + /// - Return a boolean on whether the verification is successful + /// + /// 1. Verify zero_check_proof on + /// + /// `f(q_0(x),...q_l(x), w_0(x),...w_d(x))` + /// + /// where `f` is the constraint polynomial i.e., + /// ```ignore + /// f(q_l, q_r, q_m, q_o, w_a, w_b, w_c) + /// = q_l w_a(x) + q_r w_b(x) + q_m w_a(x)w_b(x) - q_o w_c(x) + /// ``` + /// in vanilla plonk, and obtain a ZeroCheckSubClaim + /// + /// 2. Verify perm_check_proof on `\{w_i(x)\}` and `permutation_oracle` + /// + /// 3. check subclaim validity + /// + /// 4. Verify the opening against the commitment: + /// - check permutation check evaluations + /// - check zero check evaluations + /// - public input consistency checks + fn verify( + vk: &Self::VerifyingKey, + pub_input: &[E::Fr], + proof: &Self::Proof, + ) -> Result { + let start = start_timer!(|| "hyperplonk verification"); + + let mut transcript = IOPTranscript::::new(b"hyperplonk"); + // witness assignment of length 2^n + let num_vars = vk.params.num_variables(); + let log_num_witness_polys = log2(vk.params.num_witness_columns()) as usize; + // number of variables in merged polynomial for Multilinear-KZG + let merged_nv = num_vars + log_num_witness_polys; + + // online public input of length 2^\ell + let ell = log2(vk.params.num_pub_input) as usize; + + let pi_poly = DenseMultilinearExtension::from_evaluations_slice(ell as usize, pub_input); + + // ======================================================================= + // 0. sanity checks + // ======================================================================= + // public input length + if pub_input.len() != vk.params.num_pub_input { + return Err(HyperPlonkErrors::InvalidProver(format!( + "Public input length is not correct: got {}, expect {}", + pub_input.len(), + 1 << ell + ))); + } + if proof.selector_batch_evals.len() - 1 != vk.params.num_selector_columns() { + return Err(HyperPlonkErrors::InvalidVerifier(format!( + "Selector length is not correct: got {}, expect {}", + proof.selector_batch_evals.len() - 1, + 1 << vk.params.num_selector_columns() + ))); + } + if proof.w_merged_batch_evals.len() != vk.params.num_witness_columns() + 3 { + return Err(HyperPlonkErrors::InvalidVerifier(format!( + "Witness length is not correct: got {}, expect {}", + proof.w_merged_batch_evals.len() - 3, + vk.params.num_witness_columns() + ))); + } + if proof.prod_batch_evals.len() - 1 != 5 { + return Err(HyperPlonkErrors::InvalidVerifier(format!( + "the number of product polynomial evaluations is not correct: got {}, expect {}", + proof.prod_batch_evals.len() - 1, + 5 + ))); + } + + // ======================================================================= + // 1. Verify zero_check_proof on + // `f(q_0(x),...q_l(x), w_0(x),...w_d(x))` + // + // where `f` is the constraint polynomial i.e., + // + // f(q_l, q_r, q_m, q_o, w_a, w_b, w_c) + // = q_l w_a(x) + q_r w_b(x) + q_m w_a(x)w_b(x) - q_o w_c(x) + // + // ======================================================================= + let step = start_timer!(|| "verify zero check"); + // Zero check and perm check have different AuxInfo + let zero_check_aux_info = VPAuxInfo:: { + max_degree: vk.params.gate_func.degree(), + num_variables: num_vars, + phantom: PhantomData::default(), + }; + // push witness to transcript + transcript.append_serializable_element(b"w", &proof.w_merged_com)?; + + let zero_check_sub_claim = >::verify( + &proof.zero_check_proof, + &zero_check_aux_info, + &mut transcript, + )?; + + let zero_check_point = &zero_check_sub_claim.point; + + // check zero check subclaim + let f_eval = eval_f( + &vk.params.gate_func, + &proof.selector_batch_evals[..vk.params.num_selector_columns()], + &proof.w_merged_batch_evals[1..], + )?; + if f_eval != zero_check_sub_claim.expected_evaluation { + return Err(HyperPlonkErrors::InvalidProof( + "zero check evaluation failed".to_string(), + )); + } + + end_timer!(step); + // ======================================================================= + // 2. Verify perm_check_proof on `\{w_i(x)\}` and `permutation_oracle` + // ======================================================================= + let step = start_timer!(|| "verify permutation check"); + + // Zero check and perm check have different AuxInfo + let perm_check_aux_info = VPAuxInfo:: { + // Prod(x) has a max degree of 2 + max_degree: 2, + // degree of merged poly + num_variables: merged_nv, + phantom: PhantomData::default(), + }; + let perm_check_sub_claim = >::verify( + &proof.perm_check_proof, + &perm_check_aux_info, + &mut transcript, + )?; + + let perm_check_point = &perm_check_sub_claim + .product_check_sub_claim + .zero_check_sub_claim + .point; + + let alpha = perm_check_sub_claim.product_check_sub_claim.alpha; + let (beta, gamma) = perm_check_sub_claim.challenges; + + // check perm check subclaim: + // proof.witness_perm_check_eval ?= perm_check_sub_claim.expected_eval + // + // Q(x) := prod(1,x) - prod(x, 0) * prod(x, 1) + // + alpha * ( + // (g(x) + beta * s_perm(x) + gamma) * prod(0, x) + // - (f(x) + beta * s_id(x) + gamma)) + // where + // - Q(x) is perm_check_sub_claim.zero_check.exp_eval + // - prod(1, x) ... from prod(x) evaluated over (1, zero_point) + // - g(x), f(x) are both w_merged over (zero_point) + // - s_perm(x) and s_id(x) from vk_param.perm_oracle + // - alpha, beta, gamma from challenge + + // we evaluate MLE directly instead of using s_id/s_perm PCS verify + // Verification takes n pairings while evaluate takes 2^n field ops. + let s_id = identity_permutation_mle::(perm_check_point.len()); + let s_id_eval = evaluate_opt(&s_id, perm_check_point); + let s_perm_eval = evaluate_opt(&vk.permutation_oracle, perm_check_point); + + let q_x_rec = proof.prod_batch_evals[1] + - proof.prod_batch_evals[2] * proof.prod_batch_evals[3] + + alpha + * ((proof.w_merged_batch_evals[0] + beta * s_perm_eval + gamma) + * proof.prod_batch_evals[0] + - (proof.w_merged_batch_evals[0] + beta * s_id_eval + gamma)); + + if q_x_rec + != perm_check_sub_claim + .product_check_sub_claim + .zero_check_sub_claim + .expected_evaluation + { + return Err(HyperPlonkErrors::InvalidVerifier( + "evaluation failed".to_string(), + )); + } + + end_timer!(step); + // ======================================================================= + // 3. Verify the opening against the commitment + // ======================================================================= + let step = start_timer!(|| "verify commitments"); + + // ======================================================================= + // 3.1 open prod(x)' evaluations + // ======================================================================= + let prod_final_query = perm_check_sub_claim.product_check_sub_claim.final_query; + let points = [ + [perm_check_point.as_slice(), &[E::Fr::zero()]].concat(), + [perm_check_point.as_slice(), &[E::Fr::one()]].concat(), + [&[E::Fr::zero()], perm_check_point.as_slice()].concat(), + [&[E::Fr::one()], perm_check_point.as_slice()].concat(), + prod_final_query.0, + ]; + + if !PCS::batch_verify_single_poly( + &vk.pcs_param, + &proof.perm_check_proof.prod_x_comm, + &points, + &proof.prod_batch_evals, + &proof.prod_batch_openings, + )? { + return Err(HyperPlonkErrors::InvalidProof( + "prod(0, x) pcs verification failed".to_string(), + )); + } + + // ======================================================================= + // 3.2 open selectors' evaluations + // ======================================================================= + let log_num_selector_polys = log2(vk.params.num_selector_columns()) as usize; + let mut points = vec![]; + for i in 0..vk.params.num_selector_columns() { + let tmp_point = + gen_eval_point(i, log_num_selector_polys, &proof.zero_check_proof.point); + points.push(tmp_point); + } + + if !PCS::batch_verify_single_poly( + &vk.pcs_param, + &vk.selector_com, + &points, + &proof.selector_batch_evals, + &proof.selector_batch_opening, + )? { + return Err(HyperPlonkErrors::InvalidProof( + "selector pcs verification failed".to_string(), + )); + } + + // ======================================================================= + // 3.2 open witnesses' evaluations + // ======================================================================= + let mut r_pi = transcript.get_and_append_challenge_vectors(b"r_pi", ell)?; + let pi_eval = evaluate_opt(&pi_poly, &r_pi); + assert_eq!( + pi_eval, + proof.w_merged_batch_evals[proof.w_merged_batch_evals.len() - 2] + ); + + r_pi = [ + vec![E::Fr::zero(); num_vars - ell], + r_pi, + vec![E::Fr::zero(); log_num_witness_polys], + ] + .concat(); + + let mut points = vec![perm_check_point.clone()]; + + for i in 0..proof.w_merged_batch_evals.len() - 3 { + points.push(gen_eval_point(i, log_num_witness_polys, zero_check_point)) + } + points.push(r_pi); + if !PCS::batch_verify_single_poly( + &vk.pcs_param, + &proof.w_merged_com, + &points, + &proof.w_merged_batch_evals, + &proof.w_merged_batch_opening, + )? { + return Err(HyperPlonkErrors::InvalidProof( + "witness for permutation check pcs verification failed".to_string(), + )); + } + + end_timer!(step); + end_timer!(start); + Ok(true) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + custom_gate::CustomizedGates, selectors::SelectorColumn, structs::HyperPlonkParams, + witness::WitnessColumn, + }; + use arithmetic::random_permutation_mle; + use ark_bls12_381::Bls12_381; + use ark_std::test_rng; + use pcs::prelude::MultilinearKzgPCS; + + #[test] + fn test_hyperplonk_e2e() -> Result<(), HyperPlonkErrors> { + // Example: + // q_L(X) * W_1(X)^5 - W_2(X) = 0 + // is represented as + // vec![ + // ( 1, Some(id_qL), vec![id_W1, id_W1, id_W1, id_W1, id_W1]), + // (-1, None, vec![id_W2]) + // ] + // + // 4 public input + // 1 selector, + // 2 witnesses, + // 2 variables for MLE, + // 4 wires, + let gates = CustomizedGates { + gates: vec![(1, Some(0), vec![0, 0, 0, 0, 0]), (-1, None, vec![1])], + }; + test_hyperplonk_helper::(gates) + } + + fn test_hyperplonk_helper( + gate_func: CustomizedGates, + ) -> Result<(), HyperPlonkErrors> { + let mut rng = test_rng(); + let pcs_srs = MultilinearKzgPCS::::gen_srs_for_testing(&mut rng, 16)?; + + let num_constraints = 4; + let num_pub_input = 4; + let nv = log2(num_constraints) as usize; + let merged_nv = nv + log2(gate_func.num_witness_columns()) as usize; + + // generate index + let params = HyperPlonkParams { + num_constraints, + num_pub_input, + gate_func, + }; + let permutation = identity_permutation_mle(merged_nv).evaluations.clone(); + let q1 = SelectorColumn(vec![E::Fr::one(), E::Fr::one(), E::Fr::one(), E::Fr::one()]); + let index = HyperPlonkIndex { + params, + permutation, + selectors: vec![q1], + }; + + // generate pk and vks + let (pk, vk) = as HyperPlonkSNARK>>::preprocess( + &index, &pcs_srs, + )?; + + // w1 := [0, 1, 2, 3] + let w1 = WitnessColumn(vec![ + E::Fr::zero(), + E::Fr::one(), + E::Fr::from(2u64), + E::Fr::from(3u64), + ]); + // w2 := [0^5, 1^5, 2^5, 3^5] + let w2 = WitnessColumn(vec![ + E::Fr::zero(), + E::Fr::one(), + E::Fr::from(32u64), + E::Fr::from(243u64), + ]); + // public input = w1 + let pi = w1.clone(); + + // generate a proof and verify + let proof = as HyperPlonkSNARK>>::prove( + &pk, + &pi.0, + &[w1.clone(), w2.clone()], + )?; + + let _verify = as HyperPlonkSNARK>>::verify( + &vk, &pi.0, &proof, + )?; + + // bad path 1: wrong permutation + let rand_perm: Vec = random_permutation_mle(merged_nv, &mut rng) + .evaluations + .clone(); + let mut bad_index = index.clone(); + bad_index.permutation = rand_perm; + // generate pk and vks + let (_, bad_vk) = as HyperPlonkSNARK>>::preprocess( + &bad_index, &pcs_srs, + )?; + assert!( + as HyperPlonkSNARK>>::verify( + &bad_vk, &pi.0, &proof, + ) + .is_err() + ); + + // bad path 2: wrong witness + let mut w1_bad = w1.clone(); + w1_bad.0[0] = E::Fr::one(); + assert!( + as HyperPlonkSNARK>>::prove( + &pk, + &pi.0, + &[w1_bad, w2], + ) + .is_err() + ); + + Ok(()) + } +} diff --git a/hyperplonk/src/structs.rs b/hyperplonk/src/structs.rs index 7b2ffb3..b63c947 100644 --- a/hyperplonk/src/structs.rs +++ b/hyperplonk/src/structs.rs @@ -1,11 +1,11 @@ //! Main module for the HyperPlonk PolyIOP. -use crate::selectors::SelectorColumn; +use crate::{custom_gate::CustomizedGates, selectors::SelectorColumn}; use ark_ec::PairingEngine; use ark_ff::PrimeField; use ark_poly::DenseMultilinearExtension; -use ark_std::cmp::max; -use jf_primitives::pcs::PolynomialCommitmentScheme; +use ark_std::log2; +use pcs::PolynomialCommitmentScheme; use poly_iop::prelude::{PermutationCheck, ZeroCheck}; use std::rc::Rc; @@ -22,80 +22,86 @@ where PCS: PolynomialCommitmentScheme, { // ======================================================================= - // PCS components: common + // witness related // ======================================================================= - /// PCS commit for witnesses + // PCS commit for witnesses pub w_merged_com: PCS::Commitment, + // Batch opening for witness commitment + // - PermCheck eval: 1 point + // - ZeroCheck evals: #witness points + // - public input eval: 1 point + pub w_merged_batch_opening: PCS::BatchProof, + // Evaluations of Witness + // - PermCheck eval: 1 point + // - ZeroCheck evals: #witness points + // - public input eval: 1 point + pub w_merged_batch_evals: Vec, // ======================================================================= - // PCS components: permutation check + // prod(x) related // ======================================================================= - /// prod(x)'s evaluations - /// sequence: prod(0,x), prod(1, x), prod(x, 0), prod(x, 1), prod(1, ..., 1, - /// 0) - pub prod_evals: Vec, - /// prod(x)'s openings - /// sequence: prod(0,x), prod(1, x), prod(x, 0), prod(x, 1), prod(1, ..., 1, - /// 0) - pub prod_openings: Vec, - /// PCS openings for witness on permutation check point - // TODO: replace me with a batch opening - pub witness_perm_check_opening: PCS::Proof, - /// Evaluates of witnesses on permutation check point - pub witness_perm_check_eval: E::Fr, - /// PCS openings for selectors on permutation check point - // TODO: replace me with a batch opening - pub perm_oracle_opening: PCS::Proof, - /// Evaluates of selectors on permutation check point - pub perm_oracle_eval: E::Fr, + // prod(x)'s openings + // - prod(0, x), + // - prod(1, x), + // - prod(x, 0), + // - prod(x, 1), + // - prod(1, ..., 1,0) + pub prod_batch_openings: PCS::BatchProof, + // prod(x)'s evaluations + // - prod(0, x), + // - prod(1, x), + // - prod(x, 0), + // - prod(x, 1), + // - prod(1, ..., 1,0) + pub prod_batch_evals: Vec, // ======================================================================= - // PCS components: zero check + // selectors related // ======================================================================= - /// PCS openings for witness on zero check point - // TODO: replace me with a batch opening - pub witness_zero_check_openings: Vec, - /// Evaluates of witnesses on zero check point - pub witness_zero_check_evals: Vec, - /// PCS openings for selectors on zero check point - // TODO: replace me with a batch opening - pub selector_oracle_openings: Vec, - /// Evaluates of selectors on zero check point - pub selector_oracle_evals: Vec, + // PCS openings for selectors on zero check point + pub selector_batch_opening: PCS::BatchProof, + // Evaluates of selectors on zero check point + pub selector_batch_evals: Vec, // ======================================================================= - // PCS components: public inputs + // IOP proofs // ======================================================================= - /// Evaluates of public inputs on r_pi from transcript - pub pi_eval: E::Fr, - /// Opening of public inputs on r_pi from transcript - pub pi_opening: PCS::Proof, - // ======================================================================= - // IOP components - // ======================================================================= - /// the custom gate zerocheck proof + // the custom gate zerocheck proof pub zero_check_proof: >::ZeroCheckProof, - /// the permutation check proof for copy constraints + // the permutation check proof for copy constraints pub perm_check_proof: PC::PermutationProof, } /// The HyperPlonk instance parameters, consists of the following: -/// - the number of variables in the poly-IOP -/// - binary log of the number of public input variables -/// - binary log of the number of selectors -/// - binary log of the number of witness wires +/// - the number of constraints +/// - number of public input columns /// - the customized gate function #[derive(Clone, Debug, Default, PartialEq, Eq)] pub struct HyperPlonkParams { - /// the number of variables in polys - pub nv: usize, - /// binary log of the public input length - pub log_pub_input_len: usize, - // binary log of the number of selectors - pub log_n_selectors: usize, - /// binary log of the number of witness wires - pub log_n_wires: usize, + /// the number of constraints + pub num_constraints: usize, + /// number of public input + // public input is only 1 column and is implicitly the first witness column. + // this size must not exceed number of constraints. + pub num_pub_input: usize, /// customized gate function pub gate_func: CustomizedGates, } +impl HyperPlonkParams { + /// Number of variables in a multilinear system + pub fn num_variables(&self) -> usize { + log2(self.num_constraints) as usize + } + + /// number of selector columns + pub fn num_selector_columns(&self) -> usize { + self.gate_func.num_selector_columns() + } + + /// number of witness columns + pub fn num_witness_columns(&self) -> usize { + self.gate_func.num_witness_columns() + } +} + /// The HyperPlonk index, consists of the following: /// - HyperPlonk parameters /// - the wire permutation @@ -107,72 +113,56 @@ pub struct HyperPlonkIndex { pub selectors: Vec>, } +impl HyperPlonkIndex { + /// Number of variables in a multilinear system + pub fn num_variables(&self) -> usize { + self.params.num_variables() + } + + /// number of selector columns + pub fn num_selector_columns(&self) -> usize { + self.params.num_selector_columns() + } + + /// number of witness columns + pub fn num_witness_columns(&self) -> usize { + self.params.num_witness_columns() + } +} + /// The HyperPlonk proving key, consists of the following: /// - the hyperplonk instance parameters /// - the preprocessed polynomials output by the indexer +/// - the commitment to the selectors +/// - the parameters for polynomial commitment #[derive(Clone, Debug, Default, PartialEq)] pub struct HyperPlonkProvingKey> { - /// hyperplonk instance parameters + /// Hyperplonk instance parameters pub params: HyperPlonkParams, - /// the preprocessed permutation polynomials + /// The preprocessed permutation polynomials pub permutation_oracle: Rc>, - /// the preprocessed selector polynomials - // TODO: merge the list into a single MLE + /// The preprocessed selector polynomials pub selector_oracles: Vec>>, - /// the parameters for PCS commitment + /// A commitment to the preprocessed selector polynomials + pub selector_com: PCS::Commitment, + /// The parameters for PCS commitment pub pcs_param: PCS::ProverParam, } /// The HyperPlonk verifying key, consists of the following: /// - the hyperplonk instance parameters -/// - the preprocessed polynomials output by the indexer +/// - the commitments to the preprocessed polynomials output by the indexer +/// - the parameters for polynomial commitment #[derive(Clone, Debug, Default, PartialEq)] pub struct HyperPlonkVerifyingKey> { - /// hyperplonk instance parameters + /// Hyperplonk instance parameters pub params: HyperPlonkParams, - /// the parameters for PCS commitment + /// The preprocessed permutation polynomials + pub permutation_oracle: Rc>, + /// The parameters for PCS commitment pub pcs_param: PCS::VerifierParam, - /// Selector's commitment - // TODO: replace me with a batch commitment - pub selector_com: Vec, + /// A commitment to the preprocessed selector polynomials + pub selector_com: PCS::Commitment, /// Permutation oracle's commitment pub perm_com: PCS::Commitment, } - -/// Customized gate is a list of tuples of -/// (coefficient, selector_index, wire_indices) -/// -/// Example: -/// q_L(X) * W_1(X)^5 - W_2(X) -/// is represented as -/// vec![ -/// ( 1, Some(id_qL), vec![id_W1, id_W1, id_W1, id_W1, id_W1]), -/// (-1, None, vec![id_W2]) -/// ] -/// -/// CustomizedGates { -/// gates: vec![ -/// (1, Some(0), vec![0, 0, 0, 0, 0]), -/// (-1, None, vec![1]) -/// ], -/// }; -/// where id_qL = 0 // first selector -/// id_W1 = 0 // first witness -/// id_w2 = 1 // second witness -/// -/// NOTE: here coeff is a signed integer, instead of a field element -#[derive(Clone, Debug, Default, PartialEq, Eq)] -pub struct CustomizedGates { - pub(crate) gates: Vec<(i64, Option, Vec)>, -} - -impl CustomizedGates { - /// The degree of the algebraic customized gate - pub fn degree(&self) -> usize { - let mut res = 0; - for x in self.gates.iter() { - res = max(res, x.2.len() + (x.1.is_some() as usize)) - } - res - } -} diff --git a/hyperplonk/src/utils.rs b/hyperplonk/src/utils.rs index c6bf8b9..a474379 100644 --- a/hyperplonk/src/utils.rs +++ b/hyperplonk/src/utils.rs @@ -1,16 +1,87 @@ -use std::rc::Rc; - +use crate::{ + custom_gate::CustomizedGates, errors::HyperPlonkErrors, structs::HyperPlonkParams, + witness::WitnessColumn, +}; use arithmetic::VirtualPolynomial; +use ark_ec::PairingEngine; use ark_ff::PrimeField; use ark_poly::DenseMultilinearExtension; +use pcs::PolynomialCommitmentScheme; +use std::{borrow::Borrow, rc::Rc}; -use crate::{ - errors::HyperPlonkErrors, - structs::{CustomizedGates, HyperPlonkParams}, - witness::WitnessColumn, -}; +/// An accumulator structure that holds a polynomial and +/// its opening points +#[derive(Debug)] +pub(super) struct PcsAccumulator> { + pub(crate) polynomial: Option, + pub(crate) poly_commit: Option, + pub(crate) points: Vec, +} + +impl> PcsAccumulator { + /// Create an empty accumulator. + pub(super) fn new() -> Self { + Self { + polynomial: None, + poly_commit: None, + points: vec![], + } + } + + /// Initialize the polynomial; requires both the polynomial + /// and its commitment. + pub(super) fn init_poly( + &mut self, + polynomial: PCS::Polynomial, + commitment: PCS::Commitment, + ) -> Result<(), HyperPlonkErrors> { + if self.polynomial.is_some() || self.poly_commit.is_some() { + return Err(HyperPlonkErrors::InvalidProver( + "poly already set for accumulator".to_string(), + )); + } + + self.polynomial = Some(polynomial); + self.poly_commit = Some(commitment); + Ok(()) + } + + /// Push a new evaluation point into the accumulator + pub(super) fn insert_point(&mut self, point: &PCS::Point) { + self.points.push(point.clone()) + } + + /// Batch open all the points over a merged polynomial. + /// A simple wrapper of PCS::multi_open + pub(super) fn batch_open( + &self, + prover_param: impl Borrow, + ) -> Result<(PCS::BatchProof, Vec), HyperPlonkErrors> { + let poly = match &self.polynomial { + Some(p) => p, + None => { + return Err(HyperPlonkErrors::InvalidProver( + "poly is set for accumulator".to_string(), + )) + }, + }; -use poly_iop::prelude::bit_decompose; + let commitment = match &self.poly_commit { + Some(p) => p, + None => { + return Err(HyperPlonkErrors::InvalidProver( + "poly is set for accumulator".to_string(), + )) + }, + }; + Ok(PCS::multi_open_single_poly( + prover_param.borrow(), + commitment, + poly, + &self.points, + )?) + } +} /// Build MLE from matrix of witnesses. /// @@ -42,30 +113,37 @@ macro_rules! build_mle { } /// Sanity-check for HyperPlonk SNARK proving -pub(crate) fn prove_sanity_check( +pub(crate) fn prover_sanity_check( params: &HyperPlonkParams, pub_input: &[F], witnesses: &[WitnessColumn], ) -> Result<(), HyperPlonkErrors> { - let num_vars = params.nv; - let ell = params.log_pub_input_len; + // public input length must be no greater than num_constraints + + if pub_input.len() > params.num_constraints { + return Err(HyperPlonkErrors::InvalidProver(format!( + "Public input length {} is greater than num constraits {}", + pub_input.len(), + params.num_pub_input + ))); + } // public input length - if pub_input.len() != 1 << ell { + if pub_input.len() != params.num_pub_input { return Err(HyperPlonkErrors::InvalidProver(format!( "Public input length is not correct: got {}, expect {}", pub_input.len(), - 1 << ell + params.num_pub_input ))); } // witnesses length for (i, w) in witnesses.iter().enumerate() { - if w.0.len() != 1 << num_vars { + if w.0.len() != params.num_constraints { return Err(HyperPlonkErrors::InvalidProver(format!( "{}-th witness length is not correct: got {}, expect {}", i, - pub_input.len(), - 1 << ell + w.0.len(), + params.num_constraints ))); } } @@ -161,17 +239,6 @@ pub(crate) fn eval_f( Ok(res) } -/// given the evaluation input `point` of the `index`-th polynomial, -/// obtain the evaluation point in the merged polynomial -pub(crate) fn gen_eval_point(index: usize, index_len: usize, point: &[F]) -> Vec { - let mut index_vec: Vec = bit_decompose(index as u64, index_len) - .into_iter() - .map(|x| F::from(x)) - .collect(); - index_vec.reverse(); - [point, &index_vec].concat() -} - #[cfg(test)] mod test { use super::*; diff --git a/hyperplonk/src/witness.rs b/hyperplonk/src/witness.rs index 60145ea..e9aaa78 100644 --- a/hyperplonk/src/witness.rs +++ b/hyperplonk/src/witness.rs @@ -9,7 +9,7 @@ use std::rc::Rc; pub struct WitnessRow(pub(crate) Vec); /// A column of witnesses of length `#constraints` -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct WitnessColumn(pub(crate) Vec); impl WitnessColumn { @@ -19,6 +19,11 @@ impl WitnessColumn { log2(self.0.len()) as usize } + /// Append a new element to the witness column + pub fn append(&mut self, new_element: F) { + self.0.push(new_element) + } + /// Build witness columns from rows pub fn from_witness_rows( witness_rows: &[WitnessRow], @@ -42,6 +47,10 @@ impl WitnessColumn { Ok(res) } + + pub fn coeff_ref(&self) -> &[F] { + self.0.as_ref() + } } impl From<&WitnessColumn> for DenseMultilinearExtension { diff --git a/pcs/Cargo.toml b/pcs/Cargo.toml new file mode 100644 index 0000000..9e2813d --- /dev/null +++ b/pcs/Cargo.toml @@ -0,0 +1,51 @@ +[package] +name = "pcs" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] + +ark-std = { version = "^0.3.0", default-features = false } +ark-serialize = { version = "^0.3.0", default-features = false, features = [ "derive" ] } +ark-ff = { version = "^0.3.0", default-features = false } +ark-ec = { version = "^0.3.0", default-features = false } +ark-poly = {version = "^0.3.0", default-features = false } +ark-sponge = {version = "^0.3.0", default-features = false} +ark-bls12-381 = { version = "0.3.0", default-features = false, features = [ "curve" ] } + +displaydoc = { version = "0.2.3", default-features = false } +derivative = { version = "2", features = ["use_core"] } + +arithmetic = { path = "../arithmetic" } +transcript = { path = "../transcript" } +util = { path = "../util" } + +rayon = { version = "1.5.2", default-features = false, optional = true } +itertools = { version = "0.10.4", optional = true } + +# Benchmarks +[[bench]] +name = "pcs-benches" +path = "benches/bench.rs" +harness = false + +[features] +# default = [ "parallel", "print-trace" ] +default = [ "parallel",] +extensive_sanity_checks = [ ] +parallel = [ + "rayon", + "itertools", + "ark-std/parallel", + "ark-ff/parallel", + "ark-poly/parallel", + "ark-ec/parallel", + "util/parallel", + "arithmetic/parallel", + ] +print-trace = [ + "ark-std/print-trace", + "arithmetic/print-trace", + ] \ No newline at end of file diff --git a/pcs/benches/bench.rs b/pcs/benches/bench.rs new file mode 100644 index 0000000..c9e40b0 --- /dev/null +++ b/pcs/benches/bench.rs @@ -0,0 +1,88 @@ +use ark_bls12_381::{Bls12_381, Fr}; +use ark_ff::UniformRand; +use ark_poly::{DenseMultilinearExtension, MultilinearExtension}; +use ark_std::{rc::Rc, test_rng}; +use pcs::{ + prelude::{MultilinearKzgPCS, PCSError, PolynomialCommitmentScheme}, + StructuredReferenceString, +}; +use std::time::Instant; + +fn main() -> Result<(), PCSError> { + bench_pcs() +} + +fn bench_pcs() -> Result<(), PCSError> { + let mut rng = test_rng(); + + // normal polynomials + let uni_params = MultilinearKzgPCS::::gen_srs_for_testing(&mut rng, 18)?; + + for nv in 4..19 { + let repetition = if nv < 10 { + 100 + } else if nv < 20 { + 50 + } else { + 10 + }; + + let poly = Rc::new(DenseMultilinearExtension::rand(nv, &mut rng)); + let (ml_ck, ml_vk) = uni_params.0.trim(nv)?; + let (uni_ck, uni_vk) = uni_params.1.trim(nv)?; + let ck = (ml_ck, uni_ck); + let vk = (ml_vk, uni_vk); + + let point: Vec<_> = (0..nv).map(|_| Fr::rand(&mut rng)).collect(); + + // commit + let com = { + let start = Instant::now(); + for _ in 0..repetition { + let _commit = MultilinearKzgPCS::commit(&ck, &poly)?; + } + + println!( + "KZG commit for {} variables: {} ns", + nv, + start.elapsed().as_nanos() / repetition as u128 + ); + + MultilinearKzgPCS::commit(&ck, &poly)? + }; + + // open + let (proof, value) = { + let start = Instant::now(); + for _ in 0..repetition { + let _open = MultilinearKzgPCS::open(&ck, &poly, &point)?; + } + + println!( + "KZG open for {} variables: {} ns", + nv, + start.elapsed().as_nanos() / repetition as u128 + ); + MultilinearKzgPCS::open(&ck, &poly, &point)? + }; + + // verify + { + let start = Instant::now(); + for _ in 0..repetition { + assert!(MultilinearKzgPCS::verify( + &vk, &com, &point, &value, &proof + )?); + } + println!( + "KZG verify for {} variables: {} ns", + nv, + start.elapsed().as_nanos() / repetition as u128 + ); + } + + println!("===================================="); + } + + Ok(()) +} diff --git a/pcs/readme.md b/pcs/readme.md new file mode 100644 index 0000000..be45aaa --- /dev/null +++ b/pcs/readme.md @@ -0,0 +1,7 @@ +KZG based multilinear polynomial commitment +----- + +# Compiling features: +- `parallel`: use multi-threading when possible. +- `print-trace`: print out user friendly information about the running time for each micro component. +- `extensive_sanity_checks`: runs additional sanity checks that is not essential and will slow down the scheme. \ No newline at end of file diff --git a/pcs/src/errors.rs b/pcs/src/errors.rs new file mode 100644 index 0000000..8fcc24b --- /dev/null +++ b/pcs/src/errors.rs @@ -0,0 +1,50 @@ +// Copyright (c) 2022 Espresso Systems (espressosys.com) +// This file is part of the Jellyfish library. + +// You should have received a copy of the MIT License +// along with the Jellyfish library. If not, see . + +//! Error module. + +use arithmetic::ArithErrors; +use ark_serialize::SerializationError; +use ark_std::string::String; +use displaydoc::Display; +use transcript::TranscriptError; + +/// A `enum` specifying the possible failure modes of the PCS. +#[derive(Display, Debug)] +pub enum PCSError { + /// Invalid Prover: {0} + InvalidProver(String), + /// Invalid Verifier: {0} + InvalidVerifier(String), + /// Invalid Proof: {0} + InvalidProof(String), + /// Invalid parameters: {0} + InvalidParameters(String), + /// An error during (de)serialization: {0} + SerializationError(SerializationError), + /// Transcript error {0} + TranscriptError(TranscriptError), + /// ArithErrors error {0} + ArithErrors(ArithErrors), +} + +impl From for PCSError { + fn from(e: ark_serialize::SerializationError) -> Self { + Self::SerializationError(e) + } +} + +impl From for PCSError { + fn from(e: TranscriptError) -> Self { + Self::TranscriptError(e) + } +} + +impl From for PCSError { + fn from(e: ArithErrors) -> Self { + Self::ArithErrors(e) + } +} diff --git a/pcs/src/lib.rs b/pcs/src/lib.rs new file mode 100644 index 0000000..4331616 --- /dev/null +++ b/pcs/src/lib.rs @@ -0,0 +1,190 @@ +mod errors; +mod multilinear_kzg; +mod structs; +mod univariate_kzg; + +pub mod prelude; + +use ark_ec::PairingEngine; +use ark_ff::Field; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use ark_std::rand::{CryptoRng, RngCore}; +use errors::PCSError; +use std::{borrow::Borrow, fmt::Debug, hash::Hash}; + +/// This trait defines APIs for polynomial commitment schemes. +/// Note that for our usage of PCS, we do not require the hiding property. +pub trait PolynomialCommitmentScheme { + /// Prover parameters + type ProverParam: Clone; + /// Verifier parameters + type VerifierParam: Clone + CanonicalSerialize + CanonicalDeserialize; + /// Structured reference string + type SRS: Clone + Debug; + /// Polynomial and its associated types + type Polynomial: Clone + + Debug + + Hash + + PartialEq + + Eq + + CanonicalSerialize + + CanonicalDeserialize; + /// Polynomial input domain + type Point: Clone + Ord + Debug + Sync + Hash + PartialEq + Eq; + /// Polynomial Evaluation + type Evaluation: Field; + /// Commitments + type Commitment: Clone + CanonicalSerialize + CanonicalDeserialize + Debug + PartialEq + Eq; + /// Batch commitments + type BatchCommitment: Clone + CanonicalSerialize + CanonicalDeserialize + Debug + PartialEq + Eq; + /// Proofs + type Proof: Clone + CanonicalSerialize + CanonicalDeserialize + Debug + PartialEq + Eq; + /// Batch proofs + type BatchProof: Clone + CanonicalSerialize + CanonicalDeserialize + Debug + PartialEq + Eq; + + /// Build SRS for testing. + /// + /// - For univariate polynomials, `supported_size` is the maximum degree. + /// - For multilinear polynomials, `supported_size` is the number of + /// variables. + /// + /// WARNING: THIS FUNCTION IS FOR TESTING PURPOSE ONLY. + /// THE OUTPUT SRS SHOULD NOT BE USED IN PRODUCTION. + fn gen_srs_for_testing( + rng: &mut R, + supported_size: usize, + ) -> Result; + + /// Trim the universal parameters to specialize the public parameters. + /// Input both `supported_degree` for univariate and + /// `supported_num_vars` for multilinear. + /// ## Note on function signature + /// Usually, data structure like SRS and ProverParam are huge and users + /// might wish to keep them in heap using different kinds of smart pointers + /// (instead of only in stack) therefore our `impl Borrow<_>` interface + /// allows for passing in any pointer type, e.g.: `trim(srs: &Self::SRS, + /// ..)` or `trim(srs: Box, ..)` or `trim(srs: Arc, + /// ..)` etc. + fn trim( + srs: impl Borrow, + supported_degree: usize, + supported_num_vars: Option, + ) -> Result<(Self::ProverParam, Self::VerifierParam), PCSError>; + + /// Generate a commitment for a polynomial + /// ## Note on function signature + /// Usually, data structure like SRS and ProverParam are huge and users + /// might wish to keep them in heap using different kinds of smart pointers + /// (instead of only in stack) therefore our `impl Borrow<_>` interface + /// allows for passing in any pointer type, e.g.: `commit(prover_param: + /// &Self::ProverParam, ..)` or `commit(prover_param: + /// Box, ..)` or `commit(prover_param: + /// Arc, ..)` etc. + fn commit( + prover_param: impl Borrow, + poly: &Self::Polynomial, + ) -> Result; + + /// Generate a commitment for a list of polynomials + fn multi_commit( + prover_param: impl Borrow, + polys: &[Self::Polynomial], + ) -> Result; + + /// On input a polynomial `p` and a point `point`, outputs a proof for the + /// same. + fn open( + prover_param: impl Borrow, + polynomial: &Self::Polynomial, + point: &Self::Point, + ) -> Result<(Self::Proof, Self::Evaluation), PCSError>; + + /// Input a list of multilinear extensions, and a same number of points, and + /// a transcript, compute a multi-opening for all the polynomials. + fn multi_open( + prover_param: impl Borrow, + multi_commitment: &Self::BatchCommitment, + polynomials: &[Self::Polynomial], + points: &[Self::Point], + ) -> Result<(Self::BatchProof, Vec), PCSError>; + + /// Input a multilinear extension, and a number of points, and + /// a transcript, compute a multi-opening for all the polynomials. + fn multi_open_single_poly( + prover_param: impl Borrow, + commitment: &Self::Commitment, + polynomials: &Self::Polynomial, + points: &[Self::Point], + ) -> Result<(Self::BatchProof, Vec), PCSError>; + + /// Verifies that `value` is the evaluation at `x` of the polynomial + /// committed inside `comm`. + fn verify( + verifier_param: &Self::VerifierParam, + commitment: &Self::Commitment, + point: &Self::Point, + value: &E::Fr, + proof: &Self::Proof, + ) -> Result; + + /// Verifies that `value_i` is the evaluation at `x_i` of the polynomial + /// `poly_i` committed inside `comm`. + fn batch_verify( + verifier_param: &Self::VerifierParam, + multi_commitment: &Self::BatchCommitment, + points: &[Self::Point], + values: &[E::Fr], + batch_proof: &Self::BatchProof, + rng: &mut R, + ) -> Result; + + /// Verifies that `value_i` is the evaluation at `x_i` of the polynomial + /// `poly` committed inside `comm`. + fn batch_verify_single_poly( + verifier_param: &Self::VerifierParam, + commitment: &Self::Commitment, + points: &[Self::Point], + values: &[E::Fr], + batch_proof: &Self::BatchProof, + ) -> Result; +} + +/// API definitions for structured reference string +pub trait StructuredReferenceString: Sized { + /// Prover parameters + type ProverParam; + /// Verifier parameters + type VerifierParam; + + /// Extract the prover parameters from the public parameters. + fn extract_prover_param(&self, supported_size: usize) -> Self::ProverParam; + /// Extract the verifier parameters from the public parameters. + fn extract_verifier_param(&self, supported_size: usize) -> Self::VerifierParam; + + /// Trim the universal parameters to specialize the public parameters + /// for polynomials to the given `supported_size`, and + /// returns committer key and verifier key. + /// + /// - For univariate polynomials, `supported_size` is the maximum degree. + /// - For multilinear polynomials, `supported_size` is 2 to the number of + /// variables. + /// + /// `supported_log_size` should be in range `1..=params.log_size` + fn trim( + &self, + supported_size: usize, + ) -> Result<(Self::ProverParam, Self::VerifierParam), PCSError>; + + /// Build SRS for testing. + /// + /// - For univariate polynomials, `supported_size` is the maximum degree. + /// - For multilinear polynomials, `supported_size` is the number of + /// variables. + /// + /// WARNING: THIS FUNCTION IS FOR TESTING PURPOSE ONLY. + /// THE OUTPUT SRS SHOULD NOT BE USED IN PRODUCTION. + fn gen_srs_for_testing( + rng: &mut R, + supported_size: usize, + ) -> Result; +} diff --git a/pcs/src/multilinear_kzg/batching/mod.rs b/pcs/src/multilinear_kzg/batching/mod.rs new file mode 100644 index 0000000..08a5070 --- /dev/null +++ b/pcs/src/multilinear_kzg/batching/mod.rs @@ -0,0 +1,11 @@ +// Copyright (c) 2022 Espresso Systems (espressosys.com) +// This file is part of the Jellyfish library. + +// You should have received a copy of the MIT License +// along with the Jellyfish library. If not, see . + +mod multi_poly; +mod single_poly; + +pub(crate) use multi_poly::*; +pub(crate) use single_poly::*; diff --git a/pcs/src/multilinear_kzg/batching/multi_poly.rs b/pcs/src/multilinear_kzg/batching/multi_poly.rs new file mode 100644 index 0000000..b3fd1bb --- /dev/null +++ b/pcs/src/multilinear_kzg/batching/multi_poly.rs @@ -0,0 +1,426 @@ +use crate::{ + multilinear_kzg::{ + open_internal, + srs::{MultilinearProverParam, MultilinearVerifierParam}, + util::compute_w_circ_l, + verify_internal, MultilinearKzgBatchProof, + }, + prelude::{Commitment, UnivariateProverParam, UnivariateVerifierParam}, + univariate_kzg::UnivariateKzgPCS, + PCSError, PolynomialCommitmentScheme, +}; +use arithmetic::{build_l, get_uni_domain, merge_polynomials}; +use ark_ec::PairingEngine; +use ark_poly::{DenseMultilinearExtension, EvaluationDomain, MultilinearExtension, Polynomial}; +use ark_std::{end_timer, format, rc::Rc, start_timer, string::ToString, vec, vec::Vec}; +use transcript::IOPTranscript; + +/// Input +/// - the prover parameters for univariate KZG, +/// - the prover parameters for multilinear KZG, +/// - a list of MLEs, +/// - a commitment to all MLEs +/// - and a same number of points, +/// compute a multi-opening for all the polynomials. +/// +/// For simplicity, this API requires each MLE to have only one point. If +/// the caller wish to use more than one points per MLE, it should be +/// handled at the caller layer, and utilize 'multi_open_same_poly_internal' +/// API. +/// +/// Returns an error if the lengths do not match. +/// +/// Returns the proof, consists of +/// - the multilinear KZG opening +/// - the univariate KZG commitment to q(x) +/// - the openings and evaluations of q(x) at omega^i and r +/// +/// Steps: +/// 1. build `l(points)` which is a list of univariate polynomials that goes +/// through the points +/// 2. build MLE `w` which is the merge of all MLEs. +/// 3. build `q(x)` which is a univariate polynomial `W circ l` +/// 4. commit to q(x) and sample r from transcript +/// transcript contains: w commitment, points, q(x)'s commitment +/// 5. build q(omega^i) and their openings +/// 6. build q(r) and its opening +/// 7. get a point `p := l(r)` +/// 8. output an opening of `w` over point `p` +/// 9. output `w(p)` +pub(crate) fn multi_open_internal( + uni_prover_param: &UnivariateProverParam, + ml_prover_param: &MultilinearProverParam, + polynomials: &[Rc>], + multi_commitment: &Commitment, + points: &[Vec], +) -> Result<(MultilinearKzgBatchProof, Vec), PCSError> { + let open_timer = start_timer!(|| "multi open"); + + // =================================== + // Sanity checks on inputs + // =================================== + let points_len = points.len(); + if points_len == 0 { + return Err(PCSError::InvalidParameters("points is empty".to_string())); + } + + if points_len != polynomials.len() { + return Err(PCSError::InvalidParameters( + "polynomial length does not match point length".to_string(), + )); + } + + let num_var = polynomials[0].num_vars(); + for poly in polynomials.iter().skip(1) { + if poly.num_vars() != num_var { + return Err(PCSError::InvalidParameters( + "polynomials do not have same num_vars".to_string(), + )); + } + } + for point in points.iter() { + if point.len() != num_var { + return Err(PCSError::InvalidParameters( + "points do not have same num_vars".to_string(), + )); + } + } + + let domain = get_uni_domain::(points_len)?; + + // 1. build `l(points)` which is a list of univariate polynomials that goes + // through the points + let uni_polys = build_l(points, &domain, true)?; + + // 2. build MLE `w` which is the merge of all MLEs. + let merge_poly = merge_polynomials(polynomials)?; + + // 3. build `q(x)` which is a univariate polynomial `W circ l` + let q_x = compute_w_circ_l(&merge_poly, &uni_polys, points.len(), true)?; + + // 4. commit to q(x) and sample r from transcript + // transcript contains: w commitment, points, q(x)'s commitment + let mut transcript = IOPTranscript::new(b"ml kzg"); + transcript.append_serializable_element(b"w", multi_commitment)?; + for point in points { + transcript.append_serializable_element(b"w", point)?; + } + + let q_x_commit = UnivariateKzgPCS::::commit(uni_prover_param, &q_x)?; + transcript.append_serializable_element(b"q(x)", &q_x_commit)?; + let r = transcript.get_and_append_challenge(b"r")?; + // 5. build q(omega^i) and their openings + let mut q_x_opens = vec![]; + let mut q_x_evals = vec![]; + for i in 0..points_len { + let (q_x_open, q_x_eval) = + UnivariateKzgPCS::::open(uni_prover_param, &q_x, &domain.element(i))?; + q_x_opens.push(q_x_open); + q_x_evals.push(q_x_eval); + #[cfg(feature = "extensive_sanity_checks")] + { + // sanity check + let point: Vec = uni_polys + .iter() + .map(|poly| poly.evaluate(&domain.element(i))) + .collect(); + let mle_eval = merge_poly.evaluate(&point).unwrap(); + if mle_eval != q_x_eval { + return Err(PCSError::InvalidProver( + "Q(omega) does not match W(l(omega))".to_string(), + )); + } + } + } + + // 6. build q(r) and its opening + let (q_x_open, q_r_value) = UnivariateKzgPCS::::open(uni_prover_param, &q_x, &r)?; + q_x_opens.push(q_x_open); + q_x_evals.push(q_r_value); + + // 7. get a point `p := l(r)` + let point: Vec = uni_polys.iter().map(|poly| poly.evaluate(&r)).collect(); + // 8. output an opening of `w` over point `p` + let (mle_opening, mle_eval) = open_internal(ml_prover_param, &merge_poly, &point)?; + + // 9. output value that is `w` evaluated at `p` (which should match `q(r)`) + if mle_eval != q_r_value { + return Err(PCSError::InvalidProver( + "Q(r) does not match W(l(r))".to_string(), + )); + } + end_timer!(open_timer); + Ok(( + MultilinearKzgBatchProof { + proof: mle_opening, + q_x_commit, + q_x_opens, + }, + q_x_evals, + )) +} + +/// Verifies that the `multi_commitment` is a valid commitment +/// to a list of MLEs for the given openings and evaluations in +/// the batch_proof. +/// +/// steps: +/// +/// 1. push w, points and q_com into transcript +/// 2. sample `r` from transcript +/// 3. check `q(r) == batch_proof.q_x_value.last` and +/// `q(omega^i) == batch_proof.q_x_value[i]` +/// 4. build `l(points)` which is a list of univariate +/// polynomials that goes through the points +/// 5. get a point `p := l(r)` +/// 6. verifies `p` is valid against multilinear KZG proof +pub(crate) fn batch_verify_internal( + uni_verifier_param: &UnivariateVerifierParam, + ml_verifier_param: &MultilinearVerifierParam, + multi_commitment: &Commitment, + points: &[Vec], + values: &[E::Fr], + batch_proof: &MultilinearKzgBatchProof, +) -> Result { + let verify_timer = start_timer!(|| "batch verify"); + + // =================================== + // Sanity checks on inputs + // =================================== + let points_len = points.len(); + if points_len == 0 { + return Err(PCSError::InvalidParameters("points is empty".to_string())); + } + + // add one here because we also have q(r) and its opening + if points_len + 1 != batch_proof.q_x_opens.len() { + return Err(PCSError::InvalidParameters(format!( + "openings length {} does not match point length {}", + points_len + 1, + batch_proof.q_x_opens.len() + ))); + } + + if points_len + 1 != values.len() { + return Err(PCSError::InvalidParameters(format!( + "values length {} does not match point length {}", + points_len + 1, + values.len() + ))); + } + + let num_var = points[0].len(); + for point in points.iter().skip(1) { + if point.len() != num_var { + return Err(PCSError::InvalidParameters(format!( + "points do not have same num_vars ({} vs {})", + point.len(), + num_var, + ))); + } + } + + let domain = get_uni_domain::(points_len)?; + // 1. push w, points and q_com into transcript + let mut transcript = IOPTranscript::new(b"ml kzg"); + transcript.append_serializable_element(b"w", multi_commitment)?; + + for point in points { + transcript.append_serializable_element(b"w", point)?; + } + + transcript.append_serializable_element(b"q(x)", &batch_proof.q_x_commit)?; + // 2. sample `r` from transcript + let r = transcript.get_and_append_challenge(b"r")?; + // 3. check `q(r) == batch_proof.q_x_value.last` and `q(omega^i) = + // batch_proof.q_x_value[i]` + for (i, value) in values.iter().enumerate().take(points_len) { + if !UnivariateKzgPCS::verify( + uni_verifier_param, + &batch_proof.q_x_commit, + &domain.element(i), + value, + &batch_proof.q_x_opens[i], + )? { + #[cfg(debug_assertion)] + println!("q(omega^{}) verification failed", i); + return Ok(false); + } + } + + if !UnivariateKzgPCS::verify( + uni_verifier_param, + &batch_proof.q_x_commit, + &r, + &values[points_len], + &batch_proof.q_x_opens[points_len], + )? { + #[cfg(debug_assertion)] + println!("q(r) verification failed"); + return Ok(false); + } + // 4. build `l(points)` which is a list of univariate polynomials that goes + // through the points + let uni_polys = build_l(points, &domain, true)?; + + // 5. get a point `p := l(r)` + let point: Vec = uni_polys.iter().map(|x| x.evaluate(&r)).collect(); + // 6. verifies `p` is valid against multilinear KZG proof + let res = verify_internal( + ml_verifier_param, + multi_commitment, + &point, + &values[points_len], + &batch_proof.proof, + )?; + #[cfg(debug_assertion)] + if !res { + println!("multilinear KZG verification failed"); + } + + end_timer!(verify_timer); + + Ok(res) +} +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + multilinear_kzg::{ + srs::MultilinearUniversalParams, + util::{compute_qx_degree, generate_evaluations_multi_poly}, + MultilinearKzgPCS, MultilinearKzgProof, + }, + prelude::UnivariateUniversalParams, + StructuredReferenceString, + }; + use arithmetic::get_batched_nv; + use ark_bls12_381::Bls12_381 as E; + use ark_ec::PairingEngine; + use ark_poly::{DenseMultilinearExtension, MultilinearExtension}; + use ark_std::{ + log2, + rand::{CryptoRng, RngCore}, + test_rng, + vec::Vec, + UniformRand, + }; + + type Fr = ::Fr; + + fn test_multi_open_helper( + uni_params: &UnivariateUniversalParams, + ml_params: &MultilinearUniversalParams, + polys: &[Rc>], + rng: &mut R, + ) -> Result<(), PCSError> { + let merged_nv = get_batched_nv(polys[0].num_vars(), polys.len()); + let qx_degree = compute_qx_degree(merged_nv, polys.len()); + let padded_qx_degree = 1usize << log2(qx_degree); + + let (uni_ck, uni_vk) = uni_params.trim(padded_qx_degree)?; + let (ml_ck, ml_vk) = ml_params.trim(merged_nv)?; + + let mut points = Vec::new(); + for poly in polys.iter() { + let point = (0..poly.num_vars()) + .map(|_| Fr::rand(rng)) + .collect::>(); + points.push(point); + } + + let evals = generate_evaluations_multi_poly(polys, &points)?; + + let com = MultilinearKzgPCS::multi_commit(&(ml_ck.clone(), uni_ck.clone()), polys)?; + let (batch_proof, evaluations) = + multi_open_internal(&uni_ck, &ml_ck, polys, &com, &points)?; + + for (a, b) in evals.iter().zip(evaluations.iter()) { + assert_eq!(a, b) + } + + // good path + assert!(batch_verify_internal( + &uni_vk, + &ml_vk, + &com, + &points, + &evaluations, + &batch_proof, + )?); + + // bad commitment + assert!(!batch_verify_internal( + &uni_vk, + &ml_vk, + &Commitment(::G1Affine::default()), + &points, + &evaluations, + &batch_proof, + )?); + + // bad points + assert!( + batch_verify_internal(&uni_vk, &ml_vk, &com, &points[1..], &[], &batch_proof,).is_err() + ); + + // bad proof + assert!(batch_verify_internal( + &uni_vk, + &ml_vk, + &com, + &points, + &evaluations, + &MultilinearKzgBatchProof { + proof: MultilinearKzgProof { proofs: Vec::new() }, + q_x_commit: Commitment(::G1Affine::default()), + q_x_opens: vec![], + }, + ) + .is_err()); + + // bad value + let mut wrong_evals = evaluations.clone(); + wrong_evals[0] = Fr::default(); + assert!(!batch_verify_internal( + &uni_vk, + &ml_vk, + &com, + &points, + &wrong_evals, + &batch_proof + )?); + + // bad q(x) commit + let mut wrong_proof = batch_proof; + wrong_proof.q_x_commit = Commitment(::G1Affine::default()); + assert!(!batch_verify_internal( + &uni_vk, + &ml_vk, + &com, + &points, + &evaluations, + &wrong_proof, + )?); + Ok(()) + } + + #[test] + fn test_multi_open_internal() -> Result<(), PCSError> { + let mut rng = test_rng(); + + let uni_params = + UnivariateUniversalParams::::gen_srs_for_testing(&mut rng, 1usize << 15)?; + let ml_params = MultilinearUniversalParams::::gen_srs_for_testing(&mut rng, 15)?; + for num_poly in 1..10 { + for nv in 1..5 { + let polys1: Vec<_> = (0..num_poly) + .map(|_| Rc::new(DenseMultilinearExtension::rand(nv, &mut rng))) + .collect(); + test_multi_open_helper(&uni_params, &ml_params, &polys1, &mut rng)?; + } + } + + Ok(()) + } +} diff --git a/pcs/src/multilinear_kzg/batching/single_poly.rs b/pcs/src/multilinear_kzg/batching/single_poly.rs new file mode 100644 index 0000000..05f071c --- /dev/null +++ b/pcs/src/multilinear_kzg/batching/single_poly.rs @@ -0,0 +1,368 @@ +// Copyright (c) 2022 Espresso Systems (espressosys.com) +// This file is part of the Jellyfish library. + +// You should have received a copy of the MIT License +// along with the Jellyfish library. If not, see . + +use crate::{ + multilinear_kzg::{ + open_internal, + srs::{MultilinearProverParam, MultilinearVerifierParam}, + util::compute_w_circ_l, + verify_internal, MultilinearKzgBatchProof, + }, + prelude::{Commitment, UnivariateProverParam, UnivariateVerifierParam}, + univariate_kzg::UnivariateKzgPCS, + PCSError, PolynomialCommitmentScheme, +}; +use arithmetic::{build_l, get_uni_domain}; +use ark_ec::PairingEngine; +use ark_poly::{DenseMultilinearExtension, EvaluationDomain, MultilinearExtension, Polynomial}; +use ark_std::{end_timer, format, rc::Rc, start_timer, string::ToString, vec, vec::Vec}; +use rayon::prelude::{IntoParallelIterator, ParallelIterator}; +use transcript::IOPTranscript; + +/// Input +/// - the prover parameters for univariate KZG, +/// - the prover parameters for multilinear KZG, +/// - a single MLE, +/// - a commitment to the MLE +/// - and a list of points, +/// compute a multi-opening for this polynomial. +/// +/// For simplicity, this API requires each MLE to have only one point. If +/// the caller wish to use more than one points per MLE, it should be +/// handled at the caller layer. +/// +/// +/// Returns the proof, consists of +/// - the multilinear KZG opening +/// - the univariate KZG commitment to q(x) +/// - the openings and evaluations of q(x) at omega^i and r +/// +/// Steps: +/// 1. build `l(points)` which is a list of univariate polynomials that goes +/// through the points +/// 3. build `q(x)` which is a univariate polynomial `W circ l` +/// 4. commit to q(x) and sample r from transcript +/// transcript contains: w commitment, points, q(x)'s commitment +/// 5. build q(omega^i) and their openings +/// 6. build q(r) and its opening +/// 7. get a point `p := l(r)` +/// 8. output an opening of `w` over point `p` +/// 9. output `w(p)` +pub(crate) fn multi_open_same_poly_internal( + uni_prover_param: &UnivariateProverParam, + ml_prover_param: &MultilinearProverParam, + polynomial: &Rc>, + commitment: &Commitment, + points: &[Vec], +) -> Result<(MultilinearKzgBatchProof, Vec), PCSError> { + let open_timer = start_timer!(|| "multi open"); + + // =================================== + // Sanity checks on inputs + // =================================== + let points_len = points.len(); + if points_len == 0 { + return Err(PCSError::InvalidParameters("points is empty".to_string())); + } + + let num_var = polynomial.num_vars(); + for point in points.iter() { + if point.len() != num_var { + return Err(PCSError::InvalidParameters( + "points do not have same num_vars".to_string(), + )); + } + } + + let domain = get_uni_domain::(points_len)?; + + // 1. build `l(points)` which is a list of univariate polynomials that goes + // through the points + let uni_polys = build_l(points, &domain, false)?; + + // 3. build `q(x)` which is a univariate polynomial `W circ l` + let q_x = compute_w_circ_l(polynomial, &uni_polys, points.len(), false)?; + + // 4. commit to q(x) and sample r from transcript + // transcript contains: w commitment, points, q(x)'s commitment + let mut transcript = IOPTranscript::new(b"ml kzg"); + transcript.append_serializable_element(b"w", commitment)?; + for point in points { + transcript.append_serializable_element(b"w", point)?; + } + + let q_x_commit = UnivariateKzgPCS::::commit(uni_prover_param, &q_x)?; + transcript.append_serializable_element(b"q(x)", &q_x_commit)?; + let r = transcript.get_and_append_challenge(b"r")?; + // 5. build q(omega^i) and their openings + let mut q_x_opens = vec![]; + let mut q_x_evals = vec![]; + for i in 0..points_len { + let (q_x_open, q_x_eval) = + UnivariateKzgPCS::::open(uni_prover_param, &q_x, &domain.element(i))?; + q_x_opens.push(q_x_open); + q_x_evals.push(q_x_eval); + + #[cfg(feature = "extensive_sanity_checks")] + { + // sanity check + let point: Vec = uni_polys + .iter() + .map(|poly| poly.evaluate(&domain.element(i))) + .collect(); + let mle_eval = polynomial.evaluate(&point).unwrap(); + if mle_eval != q_x_eval { + return Err(PCSError::InvalidProver( + "Q(omega) does not match W(l(omega))".to_string(), + )); + } + } + } + + // 6. build q(r) and its opening + let (q_x_open, q_r_value) = UnivariateKzgPCS::::open(uni_prover_param, &q_x, &r)?; + q_x_opens.push(q_x_open); + q_x_evals.push(q_r_value); + + // 7. get a point `p := l(r)` + let point: Vec = uni_polys + .into_par_iter() + .map(|poly| poly.evaluate(&r)) + .collect(); + // 8. output an opening of `w` over point `p` + let (mle_opening, mle_eval) = open_internal(ml_prover_param, polynomial, &point)?; + + // 9. output value that is `w` evaluated at `p` (which should match `q(r)`) + if mle_eval != q_r_value { + return Err(PCSError::InvalidProver( + "Q(r) does not match W(l(r))".to_string(), + )); + } + end_timer!(open_timer); + Ok(( + MultilinearKzgBatchProof { + proof: mle_opening, + q_x_commit, + q_x_opens, + }, + q_x_evals, + )) +} + +/// Verifies that the `multi_commitment` is a valid commitment +/// to a list of MLEs for the given openings and evaluations in +/// the batch_proof. +/// +/// steps: +/// +/// 1. push w, points and q_com into transcript +/// 2. sample `r` from transcript +/// 3. check `q(r) == batch_proof.q_x_value.last` and +/// `q(omega^i) == batch_proof.q_x_value[i]` +/// 4. build `l(points)` which is a list of univariate +/// polynomials that goes through the points +/// 5. get a point `p := l(r)` +/// 6. verifies `p` is valid against multilinear KZG proof +#[allow(dead_code)] +pub(crate) fn batch_verify_same_poly_internal( + uni_verifier_param: &UnivariateVerifierParam, + ml_verifier_param: &MultilinearVerifierParam, + multi_commitment: &Commitment, + points: &[Vec], + values: &[E::Fr], + batch_proof: &MultilinearKzgBatchProof, +) -> Result { + let verify_timer = start_timer!(|| "batch verify"); + + // =================================== + // Sanity checks on inputs + // =================================== + let points_len = points.len(); + if points_len == 0 { + return Err(PCSError::InvalidParameters("points is empty".to_string())); + } + + // add one here because we also have q(r) and its opening + if points_len + 1 != batch_proof.q_x_opens.len() { + return Err(PCSError::InvalidParameters(format!( + "openings length {} does not match point length {}", + points_len + 1, + batch_proof.q_x_opens.len() + ))); + } + + if points_len + 1 != values.len() { + return Err(PCSError::InvalidParameters(format!( + "values length {} does not match point length {}", + points_len + 1, + values.len() + ))); + } + + let num_var = points[0].len(); + for point in points.iter().skip(1) { + if point.len() != num_var { + return Err(PCSError::InvalidParameters(format!( + "points do not have same num_vars ({} vs {})", + point.len(), + num_var, + ))); + } + } + + let domain = get_uni_domain::(points_len)?; + // 1. push w, points and q_com into transcript + let mut transcript = IOPTranscript::new(b"ml kzg"); + transcript.append_serializable_element(b"w", multi_commitment)?; + + for point in points { + transcript.append_serializable_element(b"w", point)?; + } + + transcript.append_serializable_element(b"q(x)", &batch_proof.q_x_commit)?; + // 2. sample `r` from transcript + let r = transcript.get_and_append_challenge(b"r")?; + // 3. check `q(r) == batch_proof.q_x_value.last` and `q(omega^i) = + // batch_proof.q_x_value[i]` + for (i, value) in values.iter().enumerate().take(points_len) { + if !UnivariateKzgPCS::verify( + uni_verifier_param, + &batch_proof.q_x_commit, + &domain.element(i), + value, + &batch_proof.q_x_opens[i], + )? { + #[cfg(debug_assertion)] + println!("q(omega^{}) verification failed", i); + return Ok(false); + } + } + + if !UnivariateKzgPCS::verify( + uni_verifier_param, + &batch_proof.q_x_commit, + &r, + &values[points_len], + &batch_proof.q_x_opens[points_len], + )? { + #[cfg(debug_assertion)] + println!("q(r) verification failed"); + return Ok(false); + } + // 4. build `l(points)` which is a list of univariate polynomials that goes + // through the points + let uni_polys = build_l(points, &domain, false)?; + + // 5. get a point `p := l(r)` + let point: Vec = uni_polys.iter().map(|x| x.evaluate(&r)).collect(); + // 6. verifies `p` is valid against multilinear KZG proof + let res = verify_internal( + ml_verifier_param, + multi_commitment, + &point, + &values[points_len], + &batch_proof.proof, + )?; + #[cfg(debug_assertion)] + if !res { + println!("multilinear KZG verification failed"); + } + + end_timer!(verify_timer); + + Ok(res) +} +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + multilinear_kzg::{ + srs::MultilinearUniversalParams, + util::{compute_qx_degree, generate_evaluations_single_poly}, + MultilinearKzgPCS, + }, + prelude::UnivariateUniversalParams, + StructuredReferenceString, + }; + use ark_bls12_381::Bls12_381 as E; + use ark_ec::PairingEngine; + use ark_poly::{DenseMultilinearExtension, MultilinearExtension}; + use ark_std::{ + log2, + rand::{CryptoRng, RngCore}, + test_rng, + vec::Vec, + UniformRand, + }; + type Fr = ::Fr; + + fn test_same_poly_multi_open_internal_helper( + uni_params: &UnivariateUniversalParams, + ml_params: &MultilinearUniversalParams, + poly: &Rc>, + point_len: usize, + rng: &mut R, + ) -> Result<(), PCSError> { + let nv = poly.num_vars; + let qx_degree = compute_qx_degree(nv, point_len); + let padded_qx_degree = 1usize << log2(qx_degree); + + let (uni_ck, uni_vk) = uni_params.trim(padded_qx_degree)?; + let (ml_ck, ml_vk) = ml_params.trim(nv)?; + + let mut points = Vec::new(); + let mut eval = Vec::new(); + for _ in 0..point_len { + let point = (0..nv).map(|_| Fr::rand(rng)).collect::>(); + eval.push(poly.evaluate(&point).unwrap()); + points.push(point); + } + + let evals = generate_evaluations_single_poly(poly, &points)?; + let com = MultilinearKzgPCS::commit(&(ml_ck.clone(), uni_ck.clone()), poly)?; + let (batch_proof, evaluations) = + multi_open_same_poly_internal(&uni_ck, &ml_ck, poly, &com, &points)?; + + for (a, b) in evals.iter().zip(evaluations.iter()) { + assert_eq!(a, b) + } + + // good path + assert!(batch_verify_same_poly_internal( + &uni_vk, + &ml_vk, + &com, + &points, + &evaluations, + &batch_proof, + )?); + + Ok(()) + } + + #[test] + fn test_same_poly_multi_open_internal() -> Result<(), PCSError> { + let mut rng = test_rng(); + + let uni_params = + UnivariateUniversalParams::::gen_srs_for_testing(&mut rng, 1usize << 15)?; + let ml_params = MultilinearUniversalParams::::gen_srs_for_testing(&mut rng, 15)?; + for nv in 1..10 { + for point_len in 1..10 { + // normal polynomials + let polys1 = Rc::new(DenseMultilinearExtension::rand(nv, &mut rng)); + test_same_poly_multi_open_internal_helper( + &uni_params, + &ml_params, + &polys1, + point_len, + &mut rng, + )?; + } + } + Ok(()) + } +} diff --git a/pcs/src/multilinear_kzg/mod.rs b/pcs/src/multilinear_kzg/mod.rs new file mode 100644 index 0000000..0e88c98 --- /dev/null +++ b/pcs/src/multilinear_kzg/mod.rs @@ -0,0 +1,644 @@ +// Copyright (c) 2022 Espresso Systems (espressosys.com) +// This file is part of the Jellyfish library. + +// You should have received a copy of the MIT License +// along with the Jellyfish library. If not, see . + +//! Main module for multilinear KZG commitment scheme + +mod batching; +pub(crate) mod srs; +pub(crate) mod util; + +use self::batching::{ + batch_verify_internal, batch_verify_same_poly_internal, multi_open_internal, + multi_open_same_poly_internal, +}; +use crate::{ + prelude::{ + Commitment, UnivariateProverParam, UnivariateUniversalParams, UnivariateVerifierParam, + }, + univariate_kzg::UnivariateKzgProof, + PCSError, PolynomialCommitmentScheme, StructuredReferenceString, +}; +use arithmetic::{evaluate_opt, merge_polynomials}; +use ark_ec::{ + msm::{FixedBaseMSM, VariableBaseMSM}, + AffineCurve, PairingEngine, ProjectiveCurve, +}; +use ark_ff::PrimeField; +use ark_poly::{DenseMultilinearExtension, MultilinearExtension}; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Read, SerializationError, Write}; +use ark_std::{ + borrow::Borrow, + end_timer, format, + marker::PhantomData, + rand::{CryptoRng, RngCore}, + rc::Rc, + start_timer, + string::ToString, + vec, + vec::Vec, + One, Zero, +}; +// use batching::{batch_verify_internal, multi_open_internal}; +use srs::{MultilinearProverParam, MultilinearUniversalParams, MultilinearVerifierParam}; + +/// KZG Polynomial Commitment Scheme on multilinear polynomials. +pub struct MultilinearKzgPCS { + #[doc(hidden)] + phantom: PhantomData, +} + +#[derive(CanonicalSerialize, CanonicalDeserialize, Clone, Debug, PartialEq, Eq)] +/// proof of opening +pub struct MultilinearKzgProof { + /// Evaluation of quotients + pub proofs: Vec, +} + +#[derive(CanonicalSerialize, CanonicalDeserialize, Clone, Debug, PartialEq, Eq)] +/// proof of batch opening +pub struct MultilinearKzgBatchProof { + /// The actual proof + pub proof: MultilinearKzgProof, + /// Commitment to q(x):= w(l(x)) where + /// - `w` is the merged MLE + /// - `l` is the list of univariate polys that goes through all points + pub q_x_commit: Commitment, + /// openings of q(x) at 1, omega, ..., and r + pub q_x_opens: Vec>, +} + +impl PolynomialCommitmentScheme for MultilinearKzgPCS { + // Parameters + type ProverParam = ( + MultilinearProverParam, + UnivariateProverParam, + ); + type VerifierParam = (MultilinearVerifierParam, UnivariateVerifierParam); + type SRS = (MultilinearUniversalParams, UnivariateUniversalParams); + // Polynomial and its associated types + type Polynomial = Rc>; + type Point = Vec; + type Evaluation = E::Fr; + // Commitments and proofs + type Commitment = Commitment; + type BatchCommitment = Commitment; + type Proof = MultilinearKzgProof; + type BatchProof = MultilinearKzgBatchProof; + + /// Build SRS for testing. + /// + /// - For univariate polynomials, `log_size` is the log of maximum degree. + /// - For multilinear polynomials, `log_size` is the number of variables. + /// + /// WARNING: THIS FUNCTION IS FOR TESTING PURPOSE ONLY. + /// THE OUTPUT SRS SHOULD NOT BE USED IN PRODUCTION. + fn gen_srs_for_testing( + rng: &mut R, + log_size: usize, + ) -> Result { + Ok(( + MultilinearUniversalParams::::gen_srs_for_testing(rng, log_size)?, + UnivariateUniversalParams::::gen_srs_for_testing(rng, 1 << log_size)?, + )) + } + + /// Trim the universal parameters to specialize the public parameters. + /// Input both `supported_log_degree` for univariate and + /// `supported_num_vars` for multilinear. + fn trim( + srs: impl Borrow, + supported_degree: usize, + supported_num_vars: Option, + ) -> Result<(Self::ProverParam, Self::VerifierParam), PCSError> { + let supported_num_vars = match supported_num_vars { + Some(p) => p, + None => { + return Err(PCSError::InvalidParameters( + "multilinear should receive a num_var param".to_string(), + )) + }, + }; + let (uni_ck, uni_vk) = srs.borrow().1.trim(supported_degree)?; + let (ml_ck, ml_vk) = srs.borrow().0.trim(supported_num_vars)?; + + Ok(((ml_ck, uni_ck), (ml_vk, uni_vk))) + } + + /// Generate a commitment for a polynomial. + /// + /// This function takes `2^num_vars` number of scalar multiplications over + /// G1. + fn commit( + prover_param: impl Borrow, + poly: &Self::Polynomial, + ) -> Result { + let prover_param = prover_param.borrow(); + let commit_timer = start_timer!(|| "commit"); + if prover_param.0.num_vars < poly.num_vars { + return Err(PCSError::InvalidParameters(format!( + "MlE length ({}) exceeds param limit ({})", + poly.num_vars, prover_param.0.num_vars + ))); + } + let ignored = prover_param.0.num_vars - poly.num_vars; + let scalars: Vec<_> = poly + .to_evaluations() + .into_iter() + .map(|x| x.into_repr()) + .collect(); + let commitment = VariableBaseMSM::multi_scalar_mul( + &prover_param.0.powers_of_g[ignored].evals, + scalars.as_slice(), + ) + .into_affine(); + + end_timer!(commit_timer); + Ok(Commitment(commitment)) + } + + /// Generate a commitment for a list of polynomials. + /// + /// This function takes `2^(num_vars + log(polys.len())` number of scalar + /// multiplications over G1. + fn multi_commit( + prover_param: impl Borrow, + polys: &[Self::Polynomial], + ) -> Result { + let prover_param = prover_param.borrow(); + let commit_timer = start_timer!(|| "multi commit"); + let poly = merge_polynomials(polys)?; + + let scalars: Vec<_> = poly + .to_evaluations() + .iter() + .map(|x| x.into_repr()) + .collect(); + + let commitment = VariableBaseMSM::multi_scalar_mul( + &prover_param.0.powers_of_g[0].evals, + scalars.as_slice(), + ) + .into_affine(); + + end_timer!(commit_timer); + Ok(Commitment(commitment)) + } + + /// On input a polynomial `p` and a point `point`, outputs a proof for the + /// same. This function does not need to take the evaluation value as an + /// input. + /// + /// This function takes 2^{num_var +1} number of scalar multiplications over + /// G1: + /// - it prodceeds with `num_var` number of rounds, + /// - at round i, we compute an MSM for `2^{num_var - i + 1}` number of G2 + /// elements. + fn open( + prover_param: impl Borrow, + polynomial: &Self::Polynomial, + point: &Self::Point, + ) -> Result<(Self::Proof, Self::Evaluation), PCSError> { + open_internal(&prover_param.borrow().0, polynomial, point) + } + + /// Input + /// - the prover parameters for univariate KZG, + /// - the prover parameters for multilinear KZG, + /// - a list of multilinear extensions (MLEs), + /// - a commitment to all multilinear extensions, + /// - and a same number of points, + /// compute a multi-opening for all the polynomials. + /// + /// For simplicity, this API requires each MLE to have only one point. If + /// the caller wish to use more than one points per MLE, it should be + /// handled at the caller layer. + /// + /// Returns an error if the lengths do not match. + /// + /// Returns the proof, consists of + /// - the multilinear KZG opening + /// - the univariate KZG commitment to q(x) + /// - the openings and evaluations of q(x) at omega^i and r + /// + /// Steps: + /// 1. build `l(points)` which is a list of univariate polynomials that goes + /// through the points + /// 2. build MLE `w` which is the merge of all MLEs. + /// 3. build `q(x)` which is a univariate polynomial `W circ l` + /// 4. commit to q(x) and sample r from transcript + /// transcript contains: w commitment, points, q(x)'s commitment + /// 5. build q(omega^i) and their openings + /// 6. build q(r) and its opening + /// 7. get a point `p := l(r)` + /// 8. output an opening of `w` over point `p` + /// 9. output `w(p)` + fn multi_open( + prover_param: impl Borrow, + multi_commitment: &Self::BatchCommitment, + polynomials: &[Self::Polynomial], + points: &[Self::Point], + ) -> Result<(Self::BatchProof, Vec), PCSError> { + multi_open_internal::( + &prover_param.borrow().1, + &prover_param.borrow().0, + polynomials, + multi_commitment, + points, + ) + } + + /// Input a multilinear extension, and a number of points, and + /// a transcript, compute a multi-opening for all the polynomials. + fn multi_open_single_poly( + prover_param: impl Borrow, + commitment: &Self::Commitment, + polynomial: &Self::Polynomial, + points: &[Self::Point], + ) -> Result<(Self::BatchProof, Vec), PCSError> { + multi_open_same_poly_internal::( + &prover_param.borrow().1, + &prover_param.borrow().0, + polynomial, + commitment, + points, + ) + } + + /// Verifies that `value` is the evaluation at `x` of the polynomial + /// committed inside `comm`. + /// + /// This function takes + /// - num_var number of pairing product. + /// - num_var number of MSM + fn verify( + verifier_param: &Self::VerifierParam, + commitment: &Self::Commitment, + point: &Self::Point, + value: &E::Fr, + proof: &Self::Proof, + ) -> Result { + verify_internal(&verifier_param.0, commitment, point, value, proof) + } + + /// Verifies that `value` is the evaluation at `x_i` of the polynomial + /// `poly_i` committed inside `comm`. + /// steps: + /// + /// 1. put `q(x)`'s evaluations over `(1, omega,...)` into transcript + /// 2. sample `r` from transcript + /// 3. check `q(r) == value` + /// 4. build `l(points)` which is a list of univariate polynomials that goes + /// through the points + /// 5. get a point `p := l(r)` + /// 6. verifies `p` is verifies against proof + fn batch_verify( + verifier_param: &Self::VerifierParam, + multi_commitment: &Self::BatchCommitment, + points: &[Self::Point], + values: &[E::Fr], + batch_proof: &Self::BatchProof, + _rng: &mut R, + ) -> Result { + batch_verify_internal( + &verifier_param.1, + &verifier_param.0, + multi_commitment, + points, + values, + batch_proof, + ) + } + + /// Verifies that `value_i` is the evaluation at `x_i` of the polynomial + /// `poly` committed inside `comm`. + fn batch_verify_single_poly( + verifier_param: &Self::VerifierParam, + commitment: &Self::Commitment, + points: &[Self::Point], + values: &[E::Fr], + batch_proof: &Self::BatchProof, + ) -> Result { + batch_verify_same_poly_internal( + &verifier_param.1, + &verifier_param.0, + commitment, + points, + values, + batch_proof, + ) + } +} + +/// On input a polynomial `p` and a point `point`, outputs a proof for the +/// same. This function does not need to take the evaluation value as an +/// input. +/// +/// This function takes 2^{num_var +1} number of scalar multiplications over +/// G1: +/// - it proceeds with `num_var` number of rounds, +/// - at round i, we compute an MSM for `2^{num_var - i + 1}` number of G2 +/// elements. +fn open_internal( + prover_param: &MultilinearProverParam, + polynomial: &DenseMultilinearExtension, + point: &[E::Fr], +) -> Result<(MultilinearKzgProof, E::Fr), PCSError> { + let open_timer = start_timer!(|| format!("open mle with {} variable", polynomial.num_vars)); + + if polynomial.num_vars() > prover_param.num_vars { + return Err(PCSError::InvalidParameters(format!( + "Polynomial num_vars {} exceed the limit {}", + polynomial.num_vars, prover_param.num_vars + ))); + } + + if polynomial.num_vars() != point.len() { + return Err(PCSError::InvalidParameters(format!( + "Polynomial num_vars {} does not match point len {}", + polynomial.num_vars, + point.len() + ))); + } + + let nv = polynomial.num_vars(); + let ignored = prover_param.num_vars - nv; + let mut r: Vec> = (0..nv + 1).map(|_| Vec::new()).collect(); + let mut q: Vec> = (0..nv + 1).map(|_| Vec::new()).collect(); + + r[nv] = polynomial.to_evaluations(); + + let mut proofs = Vec::new(); + + for (i, (&point_at_k, gi)) in point + .iter() + .zip(prover_param.powers_of_g[ignored..].iter()) + .take(nv) + .enumerate() + { + let ith_round = start_timer!(|| format!("{}-th round", i)); + + let k = nv - i; + let cur_dim = 1 << (k - 1); + let mut cur_q = vec![E::Fr::zero(); cur_dim]; + let mut cur_r = vec![E::Fr::zero(); cur_dim]; + let one_minus_point_at_k = E::Fr::one() - point_at_k; + + let ith_round_eval = start_timer!(|| format!("{}-th round eval", i)); + for b in 0..(1 << (k - 1)) { + // q_b = pre_r [2^b + 1] - pre_r [2^b] + cur_q[b] = r[k][(b << 1) + 1] - r[k][b << 1]; + + // r_b = pre_r [2^b]*(1-p) + pre_r [2^b + 1] * p + cur_r[b] = r[k][b << 1] * one_minus_point_at_k + (r[k][(b << 1) + 1] * point_at_k); + } + end_timer!(ith_round_eval); + let scalars: Vec<_> = (0..(1 << k)).map(|x| cur_q[x >> 1].into_repr()).collect(); + + q[k] = cur_q; + r[k - 1] = cur_r; + + // this is a MSM over G1 and is likely to be the bottleneck + proofs.push(VariableBaseMSM::multi_scalar_mul(&gi.evals, &scalars).into_affine()); + end_timer!(ith_round); + } + let eval = evaluate_opt(polynomial, point); + end_timer!(open_timer); + Ok((MultilinearKzgProof { proofs }, eval)) +} + +/// Verifies that `value` is the evaluation at `x` of the polynomial +/// committed inside `comm`. +/// +/// This function takes +/// - num_var number of pairing product. +/// - num_var number of MSM +fn verify_internal( + verifier_param: &MultilinearVerifierParam, + commitment: &Commitment, + point: &[E::Fr], + value: &E::Fr, + proof: &MultilinearKzgProof, +) -> Result { + let verify_timer = start_timer!(|| "verify"); + let num_var = point.len(); + + if num_var > verifier_param.num_vars { + return Err(PCSError::InvalidParameters(format!( + "point length ({}) exceeds param limit ({})", + num_var, verifier_param.num_vars + ))); + } + + let ignored = verifier_param.num_vars - num_var; + let prepare_inputs_timer = start_timer!(|| "prepare pairing inputs"); + + let scalar_size = E::Fr::size_in_bits(); + let window_size = FixedBaseMSM::get_mul_window_size(num_var); + + let h_table = FixedBaseMSM::get_window_table( + scalar_size, + window_size, + verifier_param.h.into_projective(), + ); + let h_mul: Vec = + FixedBaseMSM::multi_scalar_mul(scalar_size, window_size, &h_table, point); + + let h_vec: Vec<_> = (0..num_var) + .map(|i| verifier_param.h_mask[ignored + i].into_projective() - h_mul[i]) + .collect(); + let h_vec: Vec = E::G2Projective::batch_normalization_into_affine(&h_vec); + end_timer!(prepare_inputs_timer); + + let pairing_product_timer = start_timer!(|| "pairing product"); + + let mut pairings: Vec<_> = proof + .proofs + .iter() + .map(|&x| E::G1Prepared::from(x)) + .zip(h_vec.into_iter().take(num_var).map(E::G2Prepared::from)) + .collect(); + + pairings.push(( + E::G1Prepared::from( + (verifier_param.g.mul(*value) - commitment.0.into_projective()).into_affine(), + ), + E::G2Prepared::from(verifier_param.h), + )); + + let res = E::product_of_pairings(pairings.iter()) == E::Fqk::one(); + + end_timer!(pairing_product_timer); + end_timer!(verify_timer); + Ok(res) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::multilinear_kzg::util::compute_qx_degree; + use arithmetic::get_batched_nv; + use ark_bls12_381::Bls12_381; + use ark_ec::PairingEngine; + use ark_poly::{DenseMultilinearExtension, MultilinearExtension}; + use ark_std::{log2, rand::RngCore, test_rng, vec::Vec, UniformRand}; + + type E = Bls12_381; + type Fr = ::Fr; + + fn test_single_helper( + params: &(MultilinearUniversalParams, UnivariateUniversalParams), + poly: &Rc>, + rng: &mut R, + ) -> Result<(), PCSError> { + let nv = poly.num_vars(); + assert_ne!(nv, 0); + let uni_degree = 1; + let (ck, vk) = MultilinearKzgPCS::trim(params, uni_degree, Some(nv + 1))?; + let point: Vec<_> = (0..nv).map(|_| Fr::rand(rng)).collect(); + let com = MultilinearKzgPCS::commit(&ck, poly)?; + let (proof, value) = MultilinearKzgPCS::open(&ck, poly, &point)?; + + assert!(MultilinearKzgPCS::verify( + &vk, &com, &point, &value, &proof + )?); + + let value = Fr::rand(rng); + assert!(!MultilinearKzgPCS::verify( + &vk, &com, &point, &value, &proof + )?); + + Ok(()) + } + + #[test] + fn test_single_commit() -> Result<(), PCSError> { + let mut rng = test_rng(); + + let params = MultilinearKzgPCS::::gen_srs_for_testing(&mut rng, 10)?; + + // normal polynomials + let poly1 = Rc::new(DenseMultilinearExtension::rand(8, &mut rng)); + test_single_helper(¶ms, &poly1, &mut rng)?; + + // single-variate polynomials + let poly2 = Rc::new(DenseMultilinearExtension::rand(1, &mut rng)); + test_single_helper(¶ms, &poly2, &mut rng)?; + + Ok(()) + } + + #[test] + fn setup_commit_verify_constant_polynomial() { + let mut rng = test_rng(); + + // normal polynomials + assert!(MultilinearKzgPCS::::gen_srs_for_testing(&mut rng, 0).is_err()); + } + + fn test_multi_open_single_poly_helper( + params: &(MultilinearUniversalParams, UnivariateUniversalParams), + poly: Rc>, + num_open: usize, + rng: &mut R, + ) -> Result<(), PCSError> { + let nv = poly.num_vars(); + assert_ne!(nv, 0); + let uni_degree = 1024; + let (ck, vk) = MultilinearKzgPCS::trim(params, uni_degree, Some(nv + 1))?; + let mut points = vec![]; + for _ in 0..num_open { + let point: Vec<_> = (0..nv).map(|_| Fr::rand(rng)).collect(); + points.push(point) + } + let com = MultilinearKzgPCS::commit(&ck, &poly)?; + let (proof, mut values) = + MultilinearKzgPCS::multi_open_single_poly(&ck, &com, &poly, &points)?; + for (a, b) in values.iter().zip(points.iter()) { + let p = poly.evaluate(&b).unwrap(); + assert_eq!(*a, p); + } + + assert!(MultilinearKzgPCS::batch_verify_single_poly( + &vk, &com, &points, &values, &proof + )?); + + values[0] = Fr::rand(rng); + assert!(!MultilinearKzgPCS::batch_verify_single_poly( + &vk, &com, &points, &values, &proof + )?); + Ok(()) + } + + #[test] + fn test_multi_open_single_poly() -> Result<(), PCSError> { + let mut rng = test_rng(); + + let params = MultilinearKzgPCS::::gen_srs_for_testing(&mut rng, 15)?; + + for nv in 1..10 { + for num_open in 2..10 { + let poly1 = Rc::new(DenseMultilinearExtension::rand(nv, &mut rng)); + test_multi_open_single_poly_helper(¶ms, poly1, num_open, &mut rng)?; + } + } + + Ok(()) + } + + fn test_multi_open_helper( + params: &(MultilinearUniversalParams, UnivariateUniversalParams), + polys: &[Rc>], + num_open: usize, + rng: &mut R, + ) -> Result<(), PCSError> { + let nv = polys[0].num_vars(); + assert_ne!(nv, 0); + let merged_nv = get_batched_nv(nv, polys.len()); + let qx_degree = compute_qx_degree(merged_nv, polys.len()); + let padded_qx_degree = 1usize << log2(qx_degree); + + let (ck, vk) = MultilinearKzgPCS::trim(params, padded_qx_degree, Some(merged_nv))?; + let mut points = vec![]; + for _ in 0..num_open { + let point: Vec<_> = (0..nv).map(|_| Fr::rand(rng)).collect(); + points.push(point) + } + let com = MultilinearKzgPCS::multi_commit(&ck, &polys)?; + let (proof, mut values) = MultilinearKzgPCS::multi_open(&ck, &com, polys, &points)?; + + assert!(MultilinearKzgPCS::batch_verify( + &vk, &com, &points, &values, &proof, rng + )?); + + values[0] = Fr::rand(rng); + assert!(!MultilinearKzgPCS::batch_verify_single_poly( + &vk, &com, &points, &values, &proof + )?); + + Ok(()) + } + + #[test] + fn test_multi_open() -> Result<(), PCSError> { + let mut rng = test_rng(); + + let params = MultilinearKzgPCS::::gen_srs_for_testing(&mut rng, 15)?; + + // normal polynomials + for nv in 1..10 { + for num_open in 1..4 { + let mut polys = vec![]; + for _ in 0..num_open { + let poly = Rc::new(DenseMultilinearExtension::rand(nv, &mut rng)); + polys.push(poly) + } + + test_multi_open_helper(¶ms, &polys, num_open, &mut rng)?; + } + } + Ok(()) + } +} diff --git a/pcs/src/multilinear_kzg/srs.rs b/pcs/src/multilinear_kzg/srs.rs new file mode 100644 index 0000000..c24fbdd --- /dev/null +++ b/pcs/src/multilinear_kzg/srs.rs @@ -0,0 +1,273 @@ +// Copyright (c) 2022 Espresso Systems (espressosys.com) +// This file is part of the Jellyfish library. + +// You should have received a copy of the MIT License +// along with the Jellyfish library. If not, see . + +//! Implementing Structured Reference Strings for multilinear polynomial KZG +use crate::{prelude::PCSError, StructuredReferenceString}; +use ark_ec::{msm::FixedBaseMSM, AffineCurve, PairingEngine, ProjectiveCurve}; +use ark_ff::{Field, PrimeField}; +use ark_poly::DenseMultilinearExtension; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Read, SerializationError, Write}; +use ark_std::{ + collections::LinkedList, + end_timer, format, + rand::{CryptoRng, RngCore}, + start_timer, + string::ToString, + vec::Vec, + UniformRand, +}; +use core::iter::FromIterator; + +/// Evaluations over {0,1}^n for G1 or G2 +#[derive(CanonicalSerialize, CanonicalDeserialize, Clone, Debug)] +pub struct Evaluations { + /// The evaluations. + pub evals: Vec, +} + +/// Universal Parameter +#[derive(CanonicalSerialize, CanonicalDeserialize, Clone, Debug)] +pub struct MultilinearUniversalParams { + /// prover parameters + pub prover_param: MultilinearProverParam, + /// h^randomness: h^t1, h^t2, ..., **h^{t_nv}** + pub h_mask: Vec, +} + +/// Prover Parameters +#[derive(CanonicalSerialize, CanonicalDeserialize, Clone, Debug)] +pub struct MultilinearProverParam { + /// number of variables + pub num_vars: usize, + /// `pp_{num_vars}`, `pp_{num_vars - 1}`, `pp_{num_vars - 2}`, ..., defined + /// by XZZPD19 + pub powers_of_g: Vec>, + /// generator for G1 + pub g: E::G1Affine, + /// generator for G2 + pub h: E::G2Affine, +} + +/// Verifier Parameters +#[derive(CanonicalSerialize, CanonicalDeserialize, Clone, Debug)] +pub struct MultilinearVerifierParam { + /// number of variables + pub num_vars: usize, + /// generator of G1 + pub g: E::G1Affine, + /// generator of G2 + pub h: E::G2Affine, + /// h^randomness: h^t1, h^t2, ..., **h^{t_nv}** + pub h_mask: Vec, +} + +impl StructuredReferenceString for MultilinearUniversalParams { + type ProverParam = MultilinearProverParam; + type VerifierParam = MultilinearVerifierParam; + + /// Extract the prover parameters from the public parameters. + fn extract_prover_param(&self, supported_num_vars: usize) -> Self::ProverParam { + let to_reduce = self.prover_param.num_vars - supported_num_vars; + + Self::ProverParam { + powers_of_g: self.prover_param.powers_of_g[to_reduce..].to_vec(), + g: self.prover_param.g, + h: self.prover_param.h, + num_vars: supported_num_vars, + } + } + + /// Extract the verifier parameters from the public parameters. + fn extract_verifier_param(&self, supported_num_vars: usize) -> Self::VerifierParam { + let to_reduce = self.prover_param.num_vars - supported_num_vars; + Self::VerifierParam { + num_vars: supported_num_vars, + g: self.prover_param.g, + h: self.prover_param.h, + h_mask: self.h_mask[to_reduce..].to_vec(), + } + } + + /// Trim the universal parameters to specialize the public parameters + /// for multilinear polynomials to the given `supported_num_vars`, and + /// returns committer key and verifier key. `supported_num_vars` should + /// be in range `1..=params.num_vars` + fn trim( + &self, + supported_num_vars: usize, + ) -> Result<(Self::ProverParam, Self::VerifierParam), PCSError> { + if supported_num_vars > self.prover_param.num_vars { + return Err(PCSError::InvalidParameters(format!( + "SRS does not support target number of vars {}", + supported_num_vars + ))); + } + + let to_reduce = self.prover_param.num_vars - supported_num_vars; + let ck = Self::ProverParam { + powers_of_g: self.prover_param.powers_of_g[to_reduce..].to_vec(), + g: self.prover_param.g, + h: self.prover_param.h, + num_vars: supported_num_vars, + }; + let vk = Self::VerifierParam { + num_vars: supported_num_vars, + g: self.prover_param.g, + h: self.prover_param.h, + h_mask: self.h_mask[to_reduce..].to_vec(), + }; + Ok((ck, vk)) + } + + /// Build SRS for testing. + /// WARNING: THIS FUNCTION IS FOR TESTING PURPOSE ONLY. + /// THE OUTPUT SRS SHOULD NOT BE USED IN PRODUCTION. + fn gen_srs_for_testing( + rng: &mut R, + num_vars: usize, + ) -> Result { + if num_vars == 0 { + return Err(PCSError::InvalidParameters( + "constant polynomial not supported".to_string(), + )); + } + + let total_timer = start_timer!(|| "SRS generation"); + + let pp_generation_timer = start_timer!(|| "Prover Param generation"); + + let g = E::G1Projective::rand(rng); + let h = E::G2Projective::rand(rng); + + let mut powers_of_g = Vec::new(); + + let t: Vec<_> = (0..num_vars).map(|_| E::Fr::rand(rng)).collect(); + let scalar_bits = E::Fr::size_in_bits(); + + let mut eq: LinkedList> = + LinkedList::from_iter(eq_extension(&t).into_iter()); + let mut eq_arr = LinkedList::new(); + let mut base = eq.pop_back().unwrap().evaluations; + + for i in (0..num_vars).rev() { + eq_arr.push_front(remove_dummy_variable(&base, i)?); + if i != 0 { + let mul = eq.pop_back().unwrap().evaluations; + base = base + .into_iter() + .zip(mul.into_iter()) + .map(|(a, b)| a * b) + .collect(); + } + } + + let mut pp_powers = Vec::new(); + let mut total_scalars = 0; + for i in 0..num_vars { + let eq = eq_arr.pop_front().unwrap(); + let pp_k_powers = (0..(1 << (num_vars - i))).map(|x| eq[x]); + pp_powers.extend(pp_k_powers); + total_scalars += 1 << (num_vars - i); + } + let window_size = FixedBaseMSM::get_mul_window_size(total_scalars); + let g_table = FixedBaseMSM::get_window_table(scalar_bits, window_size, g); + + let pp_g = E::G1Projective::batch_normalization_into_affine( + &FixedBaseMSM::multi_scalar_mul(scalar_bits, window_size, &g_table, &pp_powers), + ); + + let mut start = 0; + for i in 0..num_vars { + let size = 1 << (num_vars - i); + let pp_k_g = Evaluations { + evals: pp_g[start..(start + size)].to_vec(), + }; + powers_of_g.push(pp_k_g); + start += size; + } + + let pp = Self::ProverParam { + num_vars, + g: g.into_affine(), + h: h.into_affine(), + powers_of_g, + }; + + end_timer!(pp_generation_timer); + + let vp_generation_timer = start_timer!(|| "VP generation"); + let h_mask = { + let window_size = FixedBaseMSM::get_mul_window_size(num_vars); + let h_table = FixedBaseMSM::get_window_table(scalar_bits, window_size, h); + E::G2Projective::batch_normalization_into_affine(&FixedBaseMSM::multi_scalar_mul( + scalar_bits, + window_size, + &h_table, + &t, + )) + }; + end_timer!(vp_generation_timer); + end_timer!(total_timer); + Ok(Self { + prover_param: pp, + h_mask, + }) + } +} + +/// fix first `pad` variables of `poly` represented in evaluation form to zero +fn remove_dummy_variable(poly: &[F], pad: usize) -> Result, PCSError> { + if pad == 0 { + return Ok(poly.to_vec()); + } + if !poly.len().is_power_of_two() { + return Err(PCSError::InvalidParameters( + "Size of polynomial should be power of two.".to_string(), + )); + } + let nv = ark_std::log2(poly.len()) as usize - pad; + Ok((0..(1 << nv)).map(|x| poly[x << pad]).collect()) +} + +/// Generate eq(t,x), a product of multilinear polynomials with fixed t. +/// eq(a,b) is takes extensions of a,b in {0,1}^num_vars such that if a and b in +/// {0,1}^num_vars are equal then this polynomial evaluates to 1. +fn eq_extension(t: &[F]) -> Vec> { + let start = start_timer!(|| "eq extension"); + + let dim = t.len(); + let mut result = Vec::new(); + for (i, &ti) in t.iter().enumerate().take(dim) { + let mut poly = Vec::with_capacity(1 << dim); + for x in 0..(1 << dim) { + let xi = if x >> i & 1 == 1 { F::one() } else { F::zero() }; + let ti_xi = ti * xi; + poly.push(ti_xi + ti_xi - xi - ti + F::one()); + } + result.push(DenseMultilinearExtension::from_evaluations_vec(dim, poly)); + } + + end_timer!(start); + result +} + +#[cfg(test)] +mod tests { + use super::*; + use ark_bls12_381::Bls12_381; + use ark_std::test_rng; + type E = Bls12_381; + + #[test] + fn test_srs_gen() -> Result<(), PCSError> { + let mut rng = test_rng(); + for nv in 4..10 { + let _ = MultilinearUniversalParams::::gen_srs_for_testing(&mut rng, nv)?; + } + + Ok(()) + } +} diff --git a/pcs/src/multilinear_kzg/util.rs b/pcs/src/multilinear_kzg/util.rs new file mode 100644 index 0000000..a475b3f --- /dev/null +++ b/pcs/src/multilinear_kzg/util.rs @@ -0,0 +1,432 @@ +// Copyright (c) 2022 Espresso Systems (espressosys.com) +// This file is part of the Jellyfish library. + +// You should have received a copy of the MIT License +// along with the Jellyfish library. If not, see . + +//! Useful utilities for KZG PCS +use crate::prelude::PCSError; +use arithmetic::evaluate_no_par; +use ark_ff::PrimeField; +use ark_poly::{ + univariate::DensePolynomial, DenseMultilinearExtension, EvaluationDomain, Evaluations, + MultilinearExtension, Polynomial, Radix2EvaluationDomain, +}; +use ark_std::{end_timer, format, log2, start_timer, string::ToString, vec::Vec}; +use rayon::prelude::{IntoParallelIterator, ParallelIterator}; + +/// For an MLE w with `mle_num_vars` variables, and `point_len` number of +/// points, compute the degree of the univariate polynomial `q(x):= w(l(x))` +/// where l(x) is a list of polynomials that go through all points. +// uni_degree is computed as `mle_num_vars * point_len`: +// - each l(x) is of degree `point_len` +// - mle has degree one +// - worst case is `\prod_{i=0}^{mle_num_vars-1} l_i(x) < point_len * mle_num_vars` +#[inline] +pub fn compute_qx_degree(mle_num_vars: usize, point_len: usize) -> usize { + mle_num_vars * ((1 << log2(point_len)) - 1) + 1 +} + +/// Compute W \circ l. +/// +/// Given an MLE W, and a list of univariate polynomials l, generate the +/// univariate polynomial that composes W with l. +/// +/// Returns an error if l's length does not matches number of variables in W. +pub(crate) fn compute_w_circ_l( + w: &DenseMultilinearExtension, + l: &[DensePolynomial], + num_points: usize, + with_suffix: bool, +) -> Result, PCSError> { + let timer = start_timer!(|| "compute W \\circ l"); + + if w.num_vars != l.len() { + return Err(PCSError::InvalidParameters(format!( + "l's length ({}) does not match num_variables ({})", + l.len(), + w.num_vars(), + ))); + } + let uni_degree = if with_suffix { + compute_qx_degree(w.num_vars() + log2(num_points) as usize, num_points) + } else { + compute_qx_degree(w.num_vars(), num_points) + }; + + let domain = match Radix2EvaluationDomain::::new(uni_degree) { + Some(p) => p, + None => { + return Err(PCSError::InvalidParameters( + "failed to build radix 2 domain".to_string(), + )) + }, + }; + + let step = start_timer!(|| format!("compute eval {}-dim domain", domain.size())); + let res_eval = (0..domain.size()) + .into_par_iter() + .map(|i| { + let l_eval: Vec = l.iter().map(|x| x.evaluate(&domain.element(i))).collect(); + evaluate_no_par(w, &l_eval) + }) + .collect(); + end_timer!(step); + + let evaluation = Evaluations::from_vec_and_domain(res_eval, domain); + let res = evaluation.interpolate(); + + end_timer!(timer); + Ok(res) +} + +/// Input a list of multilinear polynomials and a list of points, +/// generate a list of evaluations. +// Note that this function is only used for testing verifications. +// In practice verifier does not see polynomials, and the `mle_values` +// are included in the `batch_proof`. +#[cfg(test)] +pub(crate) fn generate_evaluations_multi_poly( + polynomials: &[std::rc::Rc>], + points: &[Vec], +) -> Result, PCSError> { + use arithmetic::{build_l, get_uni_domain, merge_polynomials}; + + if polynomials.len() != points.len() { + return Err(PCSError::InvalidParameters( + "polynomial length does not match point length".to_string(), + )); + } + let uni_poly_degree = points.len(); + let merge_poly = merge_polynomials(polynomials)?; + + let domain = get_uni_domain::(uni_poly_degree)?; + let uni_polys = build_l(points, &domain, true)?; + let mut mle_values = vec![]; + + for i in 0..uni_poly_degree { + let point: Vec = uni_polys + .iter() + .map(|poly| poly.evaluate(&domain.element(i))) + .collect(); + + let mle_value = merge_poly.evaluate(&point).unwrap(); + mle_values.push(mle_value) + } + Ok(mle_values) +} + +/// Input a list of multilinear polynomials and a list of points, +/// generate a list of evaluations. +// Note that this function is only used for testing verifications. +// In practice verifier does not see polynomials, and the `mle_values` +// are included in the `batch_proof`. +#[cfg(test)] +pub(crate) fn generate_evaluations_single_poly( + polynomial: &std::rc::Rc>, + points: &[Vec], +) -> Result, PCSError> { + use arithmetic::{build_l, get_uni_domain}; + + let uni_poly_degree = points.len(); + + let domain = get_uni_domain::(uni_poly_degree)?; + let uni_polys = build_l(points, &domain, false)?; + let mut mle_values = vec![]; + + for i in 0..uni_poly_degree { + let point: Vec = uni_polys + .iter() + .map(|poly| poly.evaluate(&domain.element(i))) + .collect(); + + let mle_value = polynomial.evaluate(&point).unwrap(); + mle_values.push(mle_value) + } + Ok(mle_values) +} + +#[cfg(test)] +mod test { + use super::*; + use arithmetic::{build_l, get_uni_domain, merge_polynomials}; + use ark_bls12_381::Fr; + use ark_poly::UVPolynomial; + use ark_std::{One, Zero}; + use std::rc::Rc; + + #[test] + fn test_w_circ_l() -> Result<(), PCSError> { + test_w_circ_l_helper::() + } + + fn test_w_circ_l_helper() -> Result<(), PCSError> { + { + // Example from page 53: + // W = 3x1x2 + 2x2 whose evaluations are + // 0, 0 |-> 0 + // 1, 0 |-> 0 + // 0, 1 |-> 2 + // 1, 1 |-> 5 + let w_eval = vec![F::zero(), F::zero(), F::from(2u64), F::from(5u64)]; + let w = DenseMultilinearExtension::from_evaluations_vec(2, w_eval); + + // l0 = t + 2 + // l1 = -2t + 4 + let l0 = DensePolynomial::from_coefficients_vec(vec![F::from(2u64), F::one()]); + let l1 = DensePolynomial::from_coefficients_vec(vec![F::from(4u64), -F::from(2u64)]); + + // res = -6t^2 - 4t + 32 + let res = compute_w_circ_l(&w, [l0, l1].as_ref(), 4, false)?; + let res_rec = DensePolynomial::from_coefficients_vec(vec![ + F::from(32u64), + -F::from(4u64), + -F::from(6u64), + ]); + assert_eq!(res, res_rec); + } + { + // A random example + // W = x1x2x3 - 2x1x2 + 3x2x3 - 4x1x3 + 5x1 - 6x2 + 7x3 + // 0, 0, 0 |-> 0 + // 1, 0, 0 |-> 5 + // 0, 1, 0 |-> -6 + // 1, 1, 0 |-> -3 + // 0, 0, 1 |-> 7 + // 1, 0, 1 |-> 8 + // 0, 1, 1 |-> 4 + // 1, 1, 1 |-> 4 + let w_eval = vec![ + F::zero(), + F::from(5u64), + -F::from(6u64), + -F::from(3u64), + F::from(7u64), + F::from(8u64), + F::from(4u64), + F::from(4u64), + ]; + let w = DenseMultilinearExtension::from_evaluations_vec(3, w_eval); + + // l0 = t + 2 + // l1 = 3t - 4 + // l2 = -5t + 6 + let l0 = DensePolynomial::from_coefficients_vec(vec![F::from(2u64), F::one()]); + let l1 = DensePolynomial::from_coefficients_vec(vec![-F::from(4u64), F::from(3u64)]); + let l2 = DensePolynomial::from_coefficients_vec(vec![F::from(6u64), -F::from(5u64)]); + let res = compute_w_circ_l(&w, [l0, l1, l2].as_ref(), 8, false)?; + + // res = -15t^3 - 23t^2 + 130t - 76 + let res_rec = DensePolynomial::from_coefficients_vec(vec![ + -F::from(76u64), + F::from(130u64), + -F::from(23u64), + -F::from(15u64), + ]); + + assert_eq!(res, res_rec); + } + Ok(()) + } + + #[test] + fn test_w_circ_l_with_prefix() -> Result<(), PCSError> { + test_w_circ_l_with_prefix_helper::() + } + + fn test_w_circ_l_with_prefix_helper() -> Result<(), PCSError> { + { + // Example from page 53: + // W = 3x1x2 + 2x2 whose evaluations are + // 0, 0 |-> 0 + // 1, 0 |-> 0 + // 0, 1 |-> 2 + // 1, 1 |-> 5 + let w_eval = vec![F::zero(), F::zero(), F::from(2u64), F::from(5u64)]; + let w = DenseMultilinearExtension::from_evaluations_vec(2, w_eval); + + // l0 = t + 2 + // l1 = -2t + 4 + let l0 = DensePolynomial::from_coefficients_vec(vec![F::from(2u64), F::one()]); + let l1 = DensePolynomial::from_coefficients_vec(vec![F::from(4u64), -F::from(2u64)]); + + // res = -6t^2 - 4t + 32 + let res = compute_w_circ_l(&w, [l0, l1].as_ref(), 4, true)?; + let res_rec = DensePolynomial::from_coefficients_vec(vec![ + F::from(32u64), + -F::from(4u64), + -F::from(6u64), + ]); + assert_eq!(res, res_rec); + } + { + // A random example + // W = x1x2x3 - 2x1x2 + 3x2x3 - 4x1x3 + 5x1 - 6x2 + 7x3 + // 0, 0, 0 |-> 0 + // 1, 0, 0 |-> 5 + // 0, 1, 0 |-> -6 + // 1, 1, 0 |-> -3 + // 0, 0, 1 |-> 7 + // 1, 0, 1 |-> 8 + // 0, 1, 1 |-> 4 + // 1, 1, 1 |-> 4 + let w_eval = vec![ + F::zero(), + F::from(5u64), + -F::from(6u64), + -F::from(3u64), + F::from(7u64), + F::from(8u64), + F::from(4u64), + F::from(4u64), + ]; + let w = DenseMultilinearExtension::from_evaluations_vec(3, w_eval); + + // l0 = t + 2 + // l1 = 3t - 4 + // l2 = -5t + 6 + let l0 = DensePolynomial::from_coefficients_vec(vec![F::from(2u64), F::one()]); + let l1 = DensePolynomial::from_coefficients_vec(vec![-F::from(4u64), F::from(3u64)]); + let l2 = DensePolynomial::from_coefficients_vec(vec![F::from(6u64), -F::from(5u64)]); + let res = compute_w_circ_l(&w, [l0, l1, l2].as_ref(), 8, true)?; + + // res = -15t^3 - 23t^2 + 130t - 76 + let res_rec = DensePolynomial::from_coefficients_vec(vec![ + -F::from(76u64), + F::from(130u64), + -F::from(23u64), + -F::from(15u64), + ]); + + assert_eq!(res, res_rec); + } + Ok(()) + } + + #[test] + fn test_qx() -> Result<(), PCSError> { + // Example from page 53: + // W1 = 3x1x2 + 2x2 + let w_eval = vec![Fr::zero(), Fr::from(2u64), Fr::zero(), Fr::from(5u64)]; + let w = Rc::new(DenseMultilinearExtension::from_evaluations_vec(2, w_eval)); + + let r = Fr::from(42u64); + + // point 1 is [1, 2] + let point1 = vec![Fr::from(1u64), Fr::from(2u64)]; + + // point 2 is [3, 4] + let point2 = vec![Fr::from(3u64), Fr::from(4u64)]; + + // point 3 is [5, 6] + let point3 = vec![Fr::from(5u64), Fr::from(6u64)]; + + { + let domain = get_uni_domain::(2)?; + let l = build_l(&[point1.clone(), point2.clone()], &domain, false)?; + + let q_x = compute_w_circ_l(&w, &l, 2, false)?; + + let point: Vec = l.iter().map(|poly| poly.evaluate(&r)).collect(); + + assert_eq!( + q_x.evaluate(&r), + w.evaluate(&point).unwrap(), + "q(r) != w(l(r))" + ); + } + + { + let domain = get_uni_domain::(3)?; + + let l = build_l(&[point1, point2, point3], &domain, false)?; + let q_x = compute_w_circ_l(&w, &l, 3, false)?; + + let point: Vec = vec![l[0].evaluate(&r), l[1].evaluate(&r)]; + + assert_eq!( + q_x.evaluate(&r), + w.evaluate(&point).unwrap(), + "q(r) != w(l(r))" + ); + } + Ok(()) + } + + #[test] + fn test_qx_with_prefix() -> Result<(), PCSError> { + // Example from page 53: + // W1 = 3x1x2 + 2x2 + let w_eval = vec![Fr::zero(), Fr::from(2u64), Fr::zero(), Fr::from(5u64)]; + let w1 = Rc::new(DenseMultilinearExtension::from_evaluations_vec(2, w_eval)); + + // W2 = x1x2 + x1 + let w_eval = vec![Fr::zero(), Fr::zero(), Fr::from(1u64), Fr::from(2u64)]; + let w2 = Rc::new(DenseMultilinearExtension::from_evaluations_vec(2, w_eval)); + + // W3 = x1 + x2 + let w_eval = vec![Fr::zero(), Fr::one(), Fr::from(1u64), Fr::from(2u64)]; + let w3 = Rc::new(DenseMultilinearExtension::from_evaluations_vec(2, w_eval)); + + let r = Fr::from(42u64); + + // point 1 is [1, 2] + let point1 = vec![Fr::from(1u64), Fr::from(2u64)]; + + // point 2 is [3, 4] + let point2 = vec![Fr::from(3u64), Fr::from(4u64)]; + + // point 3 is [5, 6] + let point3 = vec![Fr::from(5u64), Fr::from(6u64)]; + + { + let domain = get_uni_domain::(2)?; + // w = (3x1x2 + 2x2)(1-x0) + (x1x2 + x1)x0 + // with evaluations: [0,2,0,5,0,0,1,2] + let w = merge_polynomials(&[w1.clone(), w2.clone()])?; + + let l = build_l(&[point1.clone(), point2.clone()], &domain, true)?; + + // sage: P. = PolynomialRing(ZZ) + // sage: l0 = -1/2 * x + 1/2 + // sage: l1 = -x + 2 + // sage: l2 = -x + 3 + // sage: w = (3 * l1 * l2 + 2 * l2) * (1-l0) + (l1 * l2 + l1) * l0 + // sage: w + // x^3 - 7/2*x^2 - 7/2*x + 16 + // + // q(x) = x^3 - 7/2*x^2 - 7/2*x + 16 + let q_x = compute_w_circ_l(&w, &l, 2, true)?; + + let point: Vec = l.iter().map(|poly| poly.evaluate(&r)).collect(); + + assert_eq!( + q_x.evaluate(&r), + w.evaluate(&point).unwrap(), + "q(r) != w(l(r))" + ); + } + + { + let domain = get_uni_domain::(3)?; + let w = merge_polynomials(&[w1, w2, w3])?; + + let l = build_l(&[point1, point2, point3], &domain, true)?; + let q_x = compute_w_circ_l(&w, &l, 3, true)?; + + let point: Vec = vec![ + l[0].evaluate(&r), + l[1].evaluate(&r), + l[2].evaluate(&r), + l[3].evaluate(&r), + ]; + + assert_eq!( + q_x.evaluate(&r), + w.evaluate(&point).unwrap(), + "q(r) != w(l(r))" + ); + } + Ok(()) + } +} diff --git a/pcs/src/prelude.rs b/pcs/src/prelude.rs new file mode 100644 index 0000000..c5623d0 --- /dev/null +++ b/pcs/src/prelude.rs @@ -0,0 +1,21 @@ +// Copyright (c) 2022 Espresso Systems (espressosys.com) +// This file is part of the Jellyfish library. + +// You should have received a copy of the MIT License +// along with the Jellyfish library. If not, see . + +//! Prelude +pub use crate::{ + errors::PCSError, + multilinear_kzg::{ + srs::{MultilinearProverParam, MultilinearUniversalParams, MultilinearVerifierParam}, + util::compute_qx_degree, + MultilinearKzgBatchProof, MultilinearKzgPCS, MultilinearKzgProof, + }, + structs::Commitment, + univariate_kzg::{ + srs::{UnivariateProverParam, UnivariateUniversalParams, UnivariateVerifierParam}, + UnivariateKzgBatchProof, UnivariateKzgPCS, UnivariateKzgProof, + }, + PolynomialCommitmentScheme, StructuredReferenceString, +}; diff --git a/pcs/src/structs.rs b/pcs/src/structs.rs new file mode 100644 index 0000000..59e2a0f --- /dev/null +++ b/pcs/src/structs.rs @@ -0,0 +1,25 @@ +// Copyright (c) 2022 Espresso Systems (espressosys.com) +// This file is part of the Jellyfish library. + +// You should have received a copy of the MIT License +// along with the Jellyfish library. If not, see . + +use ark_ec::PairingEngine; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Read, SerializationError, Write}; +use derivative::Derivative; + +#[derive(Derivative, CanonicalSerialize, CanonicalDeserialize)] +#[derivative( + Default(bound = ""), + Hash(bound = ""), + Clone(bound = ""), + Copy(bound = ""), + Debug(bound = ""), + PartialEq(bound = ""), + Eq(bound = "") +)] +/// A commitment is an Affine point. +pub struct Commitment( + /// the actual commitment is an affine point. + pub E::G1Affine, +); diff --git a/pcs/src/univariate_kzg/mod.rs b/pcs/src/univariate_kzg/mod.rs new file mode 100644 index 0000000..1c7faed --- /dev/null +++ b/pcs/src/univariate_kzg/mod.rs @@ -0,0 +1,438 @@ +// Copyright (c) 2022 Espresso Systems (espressosys.com) +// This file is part of the Jellyfish library. + +// You should have received a copy of the MIT License +// along with the Jellyfish library. If not, see . + +//! Main module for univariate KZG commitment scheme + +use crate::{prelude::Commitment, PCSError, PolynomialCommitmentScheme, StructuredReferenceString}; +use ark_ec::{msm::VariableBaseMSM, AffineCurve, PairingEngine, ProjectiveCurve}; +use ark_ff::PrimeField; +use ark_poly::{univariate::DensePolynomial, Polynomial, UVPolynomial}; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Read, SerializationError, Write}; +use ark_std::{ + borrow::Borrow, + end_timer, format, + marker::PhantomData, + rand::{CryptoRng, RngCore}, + start_timer, + string::ToString, + vec, + vec::Vec, + One, UniformRand, Zero, +}; +#[cfg(feature = "parallel")] +use rayon::prelude::*; +use srs::{UnivariateProverParam, UnivariateUniversalParams, UnivariateVerifierParam}; +use util::parallelizable_slice_iter; + +pub(crate) mod srs; + +/// KZG Polynomial Commitment Scheme on univariate polynomial. +pub struct UnivariateKzgPCS { + #[doc(hidden)] + phantom: PhantomData, +} + +#[derive(CanonicalSerialize, CanonicalDeserialize, Clone, Debug, PartialEq, Eq)] +/// proof of opening +pub struct UnivariateKzgProof { + /// Evaluation of quotients + pub proof: E::G1Affine, +} +/// batch proof +pub type UnivariateKzgBatchProof = Vec>; + +impl PolynomialCommitmentScheme for UnivariateKzgPCS { + // Parameters + type ProverParam = UnivariateProverParam; + type VerifierParam = UnivariateVerifierParam; + type SRS = UnivariateUniversalParams; + // Polynomial and its associated types + type Polynomial = DensePolynomial; + type Point = E::Fr; + type Evaluation = E::Fr; + // Polynomial and its associated types + type Commitment = Commitment; + type BatchCommitment = Vec; + type Proof = UnivariateKzgProof; + type BatchProof = UnivariateKzgBatchProof; + + /// Build SRS for testing. + /// + /// - For univariate polynomials, `supported_size` is the maximum degree. + /// + /// WARNING: THIS FUNCTION IS FOR TESTING PURPOSE ONLY. + /// THE OUTPUT SRS SHOULD NOT BE USED IN PRODUCTION. + fn gen_srs_for_testing( + rng: &mut R, + supported_size: usize, + ) -> Result { + Self::SRS::gen_srs_for_testing(rng, supported_size) + } + + /// Trim the universal parameters to specialize the public parameters. + /// Input `max_degree` for univariate. + /// `supported_num_vars` must be None or an error is returned. + fn trim( + srs: impl Borrow, + supported_degree: usize, + supported_num_vars: Option, + ) -> Result<(Self::ProverParam, Self::VerifierParam), PCSError> { + if supported_num_vars.is_some() { + return Err(PCSError::InvalidParameters( + "univariate should not receive a num_var param".to_string(), + )); + } + srs.borrow().trim(supported_degree) + } + + /// Generate a commitment for a polynomial + /// Note that the scheme is not hidding + fn commit( + prover_param: impl Borrow, + poly: &Self::Polynomial, + ) -> Result { + let prover_param = prover_param.borrow(); + let commit_time = + start_timer!(|| format!("Committing to polynomial of degree {} ", poly.degree())); + + if poly.degree() >= prover_param.powers_of_g.len() { + return Err(PCSError::InvalidParameters(format!( + "uni poly degree {} is larger than allowed {}", + poly.degree(), + prover_param.powers_of_g.len() + ))); + } + + let (num_leading_zeros, plain_coeffs) = skip_leading_zeros_and_convert_to_bigints(poly); + + let msm_time = start_timer!(|| "MSM to compute commitment to plaintext poly"); + let commitment = VariableBaseMSM::multi_scalar_mul( + &prover_param.powers_of_g[num_leading_zeros..], + &plain_coeffs, + ) + .into_affine(); + end_timer!(msm_time); + + end_timer!(commit_time); + Ok(Commitment(commitment)) + } + + /// Generate a commitment for a list of polynomials + fn multi_commit( + prover_param: impl Borrow, + polys: &[Self::Polynomial], + ) -> Result { + let prover_param = prover_param.borrow(); + let commit_time = start_timer!(|| format!("batch commit {} polynomials", polys.len())); + let res = parallelizable_slice_iter(polys) + .map(|poly| Self::commit(prover_param, poly)) + .collect::, PCSError>>()?; + + end_timer!(commit_time); + Ok(res) + } + + /// On input a polynomial `p` and a point `point`, outputs a proof for the + /// same. + fn open( + prover_param: impl Borrow, + polynomial: &Self::Polynomial, + point: &Self::Point, + ) -> Result<(Self::Proof, Self::Evaluation), PCSError> { + let open_time = + start_timer!(|| format!("Opening polynomial of degree {}", polynomial.degree())); + let divisor = Self::Polynomial::from_coefficients_vec(vec![-*point, E::Fr::one()]); + + let witness_time = start_timer!(|| "Computing witness polynomial"); + let witness_polynomial = polynomial / &divisor; + end_timer!(witness_time); + + let (num_leading_zeros, witness_coeffs) = + skip_leading_zeros_and_convert_to_bigints(&witness_polynomial); + + let proof = VariableBaseMSM::multi_scalar_mul( + &prover_param.borrow().powers_of_g[num_leading_zeros..], + &witness_coeffs, + ) + .into_affine(); + + let eval = polynomial.evaluate(point); + + end_timer!(open_time); + Ok((Self::Proof { proof }, eval)) + } + + /// Input a list of polynomials, and a same number of points, + /// compute a multi-opening for all the polynomials. + // This is a naive approach + // TODO: to implement the more efficient batch opening algorithm + // (e.g., the appendix C.4 in https://eprint.iacr.org/2020/1536.pdf) + fn multi_open( + prover_param: impl Borrow, + _multi_commitment: &Self::BatchCommitment, + polynomials: &[Self::Polynomial], + points: &[Self::Point], + ) -> Result<(Self::BatchProof, Vec), PCSError> { + let open_time = start_timer!(|| format!("batch opening {} polynomials", polynomials.len())); + if polynomials.len() != points.len() { + return Err(PCSError::InvalidParameters(format!( + "poly length {} is different from points length {}", + polynomials.len(), + points.len() + ))); + } + let mut batch_proof = vec![]; + let mut evals = vec![]; + for (poly, point) in polynomials.iter().zip(points.iter()) { + let (proof, eval) = Self::open(prover_param.borrow(), poly, point)?; + batch_proof.push(proof); + evals.push(eval); + } + + end_timer!(open_time); + Ok((batch_proof, evals)) + } + + /// Input a multilinear extension, and a number of points, and + /// a transcript, compute a multi-opening for all the polynomials. + fn multi_open_single_poly( + _prover_param: impl Borrow, + _commitment: &Self::Commitment, + _polynomials: &Self::Polynomial, + _points: &[Self::Point], + ) -> Result<(Self::BatchProof, Vec), PCSError> { + unimplemented!() + } + + /// Verifies that `value` is the evaluation at `x` of the polynomial + /// committed inside `comm`. + fn verify( + verifier_param: &Self::VerifierParam, + commitment: &Self::Commitment, + point: &Self::Point, + value: &E::Fr, + proof: &Self::Proof, + ) -> Result { + let check_time = start_timer!(|| "Checking evaluation"); + let pairing_inputs: Vec<(E::G1Prepared, E::G2Prepared)> = vec![ + ( + (verifier_param.g.mul(value.into_repr()) + - proof.proof.mul(point.into_repr()) + - commitment.0.into_projective()) + .into_affine() + .into(), + verifier_param.h.into(), + ), + (proof.proof.into(), verifier_param.beta_h.into()), + ]; + + let res = E::product_of_pairings(pairing_inputs.iter()).is_one(); + + end_timer!(check_time, || format!("Result: {}", res)); + Ok(res) + } + + /// Verifies that `value_i` is the evaluation at `x_i` of the polynomial + /// `poly_i` committed inside `comm`. + // This is a naive approach + // TODO: to implement the more efficient batch verification algorithm + // (e.g., the appendix C.4 in https://eprint.iacr.org/2020/1536.pdf) + fn batch_verify( + verifier_param: &Self::VerifierParam, + multi_commitment: &Self::BatchCommitment, + points: &[Self::Point], + values: &[E::Fr], + batch_proof: &Self::BatchProof, + rng: &mut R, + ) -> Result { + let check_time = + start_timer!(|| format!("Checking {} evaluation proofs", multi_commitment.len())); + + let mut total_c = ::zero(); + let mut total_w = ::zero(); + + let combination_time = start_timer!(|| "Combining commitments and proofs"); + let mut randomizer = E::Fr::one(); + // Instead of multiplying g and gamma_g in each turn, we simply accumulate + // their coefficients and perform a final multiplication at the end. + let mut g_multiplier = E::Fr::zero(); + for (((c, z), v), proof) in multi_commitment + .iter() + .zip(points) + .zip(values) + .zip(batch_proof) + { + let w = proof.proof; + let mut temp = w.mul(*z); + temp.add_assign_mixed(&c.0); + let c = temp; + g_multiplier += &(randomizer * v); + total_c += &c.mul(randomizer.into_repr()); + total_w += &w.mul(randomizer.into_repr()); + // We don't need to sample randomizers from the full field, + // only from 128-bit strings. + randomizer = u128::rand(rng).into(); + } + total_c -= &verifier_param.g.mul(g_multiplier); + end_timer!(combination_time); + + let to_affine_time = start_timer!(|| "Converting results to affine for pairing"); + let affine_points = E::G1Projective::batch_normalization_into_affine(&[-total_w, total_c]); + let (total_w, total_c) = (affine_points[0], affine_points[1]); + end_timer!(to_affine_time); + + let pairing_time = start_timer!(|| "Performing product of pairings"); + let result = E::product_of_pairings(&[ + (total_w.into(), verifier_param.beta_h.into()), + (total_c.into(), verifier_param.h.into()), + ]) + .is_one(); + end_timer!(pairing_time); + end_timer!(check_time, || format!("Result: {}", result)); + Ok(result) + } + + /// Verifies that `value_i` is the evaluation at `x_i` of the polynomial + /// `poly` committed inside `comm`. + fn batch_verify_single_poly( + _verifier_param: &Self::VerifierParam, + _commitment: &Self::Commitment, + _points: &[Self::Point], + _values: &[E::Fr], + _batch_proof: &Self::BatchProof, + ) -> Result { + unimplemented!() + } +} + +fn skip_leading_zeros_and_convert_to_bigints>( + p: &P, +) -> (usize, Vec) { + let mut num_leading_zeros = 0; + while num_leading_zeros < p.coeffs().len() && p.coeffs()[num_leading_zeros].is_zero() { + num_leading_zeros += 1; + } + let coeffs = convert_to_bigints(&p.coeffs()[num_leading_zeros..]); + (num_leading_zeros, coeffs) +} + +fn convert_to_bigints(p: &[F]) -> Vec { + let to_bigint_time = start_timer!(|| "Converting polynomial coeffs to bigints"); + let coeffs = p.iter().map(|s| s.into_repr()).collect::>(); + end_timer!(to_bigint_time); + coeffs +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::StructuredReferenceString; + use ark_bls12_381::Bls12_381; + use ark_ec::PairingEngine; + use ark_poly::univariate::DensePolynomial; + use ark_std::{test_rng, UniformRand}; + + fn end_to_end_test_template() -> Result<(), PCSError> + where + E: PairingEngine, + { + let rng = &mut test_rng(); + for _ in 0..100 { + let mut degree = 0; + while degree <= 1 { + degree = usize::rand(rng) % 20; + } + let pp = UnivariateKzgPCS::::gen_srs_for_testing(rng, degree)?; + let (ck, vk) = pp.trim(degree)?; + let p = as UVPolynomial>::rand(degree, rng); + let comm = UnivariateKzgPCS::::commit(&ck, &p)?; + let point = E::Fr::rand(rng); + let (proof, value) = UnivariateKzgPCS::::open(&ck, &p, &point)?; + assert!( + UnivariateKzgPCS::::verify(&vk, &comm, &point, &value, &proof)?, + "proof was incorrect for max_degree = {}, polynomial_degree = {}", + degree, + p.degree(), + ); + } + Ok(()) + } + + fn linear_polynomial_test_template() -> Result<(), PCSError> + where + E: PairingEngine, + { + let rng = &mut test_rng(); + for _ in 0..100 { + let degree = 50; + + let pp = UnivariateKzgPCS::::gen_srs_for_testing(rng, degree)?; + let (ck, vk) = pp.trim(degree)?; + let p = as UVPolynomial>::rand(degree, rng); + let comm = UnivariateKzgPCS::::commit(&ck, &p)?; + let point = E::Fr::rand(rng); + let (proof, value) = UnivariateKzgPCS::::open(&ck, &p, &point)?; + assert!( + UnivariateKzgPCS::::verify(&vk, &comm, &point, &value, &proof)?, + "proof was incorrect for max_degree = {}, polynomial_degree = {}", + degree, + p.degree(), + ); + } + Ok(()) + } + + fn batch_check_test_template() -> Result<(), PCSError> + where + E: PairingEngine, + { + let rng = &mut test_rng(); + for _ in 0..10 { + let mut degree = 0; + while degree <= 1 { + degree = usize::rand(rng) % 20; + } + let pp = UnivariateKzgPCS::::gen_srs_for_testing(rng, degree)?; + let (ck, vk) = UnivariateKzgPCS::::trim(&pp, degree, None)?; + let mut comms = Vec::new(); + let mut values = Vec::new(); + let mut points = Vec::new(); + let mut proofs = Vec::new(); + for _ in 0..10 { + let p = as UVPolynomial>::rand(degree, rng); + let comm = UnivariateKzgPCS::::commit(&ck, &p)?; + let point = E::Fr::rand(rng); + let (proof, value) = UnivariateKzgPCS::::open(&ck, &p, &point)?; + + assert!(UnivariateKzgPCS::::verify( + &vk, &comm, &point, &value, &proof + )?); + comms.push(comm); + values.push(value); + points.push(point); + proofs.push(proof); + } + assert!(UnivariateKzgPCS::::batch_verify( + &vk, &comms, &points, &values, &proofs, rng + )?); + } + Ok(()) + } + + #[test] + fn end_to_end_test() { + end_to_end_test_template::().expect("test failed for bls12-381"); + } + + #[test] + fn linear_polynomial_test() { + linear_polynomial_test_template::().expect("test failed for bls12-381"); + } + #[test] + fn batch_check_test() { + batch_check_test_template::().expect("test failed for bls12-381"); + } +} diff --git a/pcs/src/univariate_kzg/srs.rs b/pcs/src/univariate_kzg/srs.rs new file mode 100644 index 0000000..933e0c4 --- /dev/null +++ b/pcs/src/univariate_kzg/srs.rs @@ -0,0 +1,156 @@ +// Copyright (c) 2022 Espresso Systems (espressosys.com) +// This file is part of the Jellyfish library. + +// You should have received a copy of the MIT License +// along with the Jellyfish library. If not, see . + +//! Implementing Structured Reference Strings for univariate polynomial KZG + +use crate::{PCSError, StructuredReferenceString}; +use ark_ec::{msm::FixedBaseMSM, AffineCurve, PairingEngine, ProjectiveCurve}; +use ark_ff::PrimeField; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Read, SerializationError, Write}; +use ark_std::{ + end_timer, + rand::{CryptoRng, RngCore}, + start_timer, vec, + vec::Vec, + One, UniformRand, +}; +use derivative::Derivative; + +/// `UniversalParams` are the universal parameters for the KZG10 scheme. +// Adapted from +// https://github.com/arkworks-rs/poly-commit/blob/master/src/kzg10/data_structures.rs#L20 +#[derive(Debug, Clone, Eq, PartialEq, CanonicalSerialize, CanonicalDeserialize, Default)] +pub struct UnivariateUniversalParams { + /// Group elements of the form `{ \beta^i G }`, where `i` ranges from 0 to + /// `degree`. + pub powers_of_g: Vec, + /// The generator of G2. + pub h: E::G2Affine, + /// \beta times the above generator of G2. + pub beta_h: E::G2Affine, +} + +impl UnivariateUniversalParams { + /// Returns the maximum supported degree + pub fn max_degree(&self) -> usize { + self.powers_of_g.len() + } +} + +/// `UnivariateProverParam` is used to generate a proof +#[derive(CanonicalSerialize, CanonicalDeserialize, Clone, Debug, Eq, PartialEq, Default)] +pub struct UnivariateProverParam { + /// Parameters + pub powers_of_g: Vec, +} + +/// `UnivariateVerifierParam` is used to check evaluation proofs for a given +/// commitment. +#[derive(Derivative, CanonicalSerialize, CanonicalDeserialize)] +#[derivative( + Default(bound = ""), + Clone(bound = ""), + Copy(bound = ""), + Debug(bound = ""), + PartialEq(bound = ""), + Eq(bound = "") +)] +pub struct UnivariateVerifierParam { + /// The generator of G1. + pub g: E::G1Affine, + /// The generator of G2. + pub h: E::G2Affine, + /// \beta times the above generator of G2. + pub beta_h: E::G2Affine, +} + +impl StructuredReferenceString for UnivariateUniversalParams { + type ProverParam = UnivariateProverParam; + type VerifierParam = UnivariateVerifierParam; + + /// Extract the prover parameters from the public parameters. + fn extract_prover_param(&self, supported_size: usize) -> Self::ProverParam { + let powers_of_g = self.powers_of_g[..=supported_size].to_vec(); + + Self::ProverParam { powers_of_g } + } + + /// Extract the verifier parameters from the public parameters. + fn extract_verifier_param(&self, _supported_size: usize) -> Self::VerifierParam { + Self::VerifierParam { + g: self.powers_of_g[0], + h: self.h, + beta_h: self.beta_h, + } + } + + /// Trim the universal parameters to specialize the public parameters + /// for univariate polynomials to the given `supported_size`, and + /// returns committer key and verifier key. `supported_size` should + /// be in range `1..params.len()` + fn trim( + &self, + supported_size: usize, + ) -> Result<(Self::ProverParam, Self::VerifierParam), PCSError> { + let powers_of_g = self.powers_of_g[..=supported_size].to_vec(); + + let pk = Self::ProverParam { powers_of_g }; + let vk = Self::VerifierParam { + g: self.powers_of_g[0], + h: self.h, + beta_h: self.beta_h, + }; + Ok((pk, vk)) + } + + /// Build SRS for testing. + /// WARNING: THIS FUNCTION IS FOR TESTING PURPOSE ONLY. + /// THE OUTPUT SRS SHOULD NOT BE USED IN PRODUCTION. + fn gen_srs_for_testing( + rng: &mut R, + max_degree: usize, + ) -> Result { + let setup_time = start_timer!(|| format!("KZG10::Setup with degree {}", max_degree)); + let beta = E::Fr::rand(rng); + let g = E::G1Projective::rand(rng); + let h = E::G2Projective::rand(rng); + + let mut powers_of_beta = vec![E::Fr::one()]; + + let mut cur = beta; + for _ in 0..max_degree { + powers_of_beta.push(cur); + cur *= β + } + + let window_size = FixedBaseMSM::get_mul_window_size(max_degree + 1); + + let scalar_bits = E::Fr::size_in_bits(); + let g_time = start_timer!(|| "Generating powers of G"); + // TODO: parallelization + let g_table = FixedBaseMSM::get_window_table(scalar_bits, window_size, g); + let powers_of_g = FixedBaseMSM::multi_scalar_mul::( + scalar_bits, + window_size, + &g_table, + &powers_of_beta, + ); + end_timer!(g_time); + + let powers_of_g = E::G1Projective::batch_normalization_into_affine(&powers_of_g); + + let h = h.into_affine(); + let beta_h = h.mul(beta).into_affine(); + + let pp = Self { + powers_of_g, + h, + beta_h, + }; + end_timer!(setup_time); + Ok(pp) + } +} diff --git a/poly-iop/Cargo.toml b/poly-iop/Cargo.toml index 2fdac37..8f9252d 100644 --- a/poly-iop/Cargo.toml +++ b/poly-iop/Cargo.toml @@ -20,8 +20,8 @@ rayon = { version = "1.5.2", default-features = false, optional = true } transcript = { path = "../transcript" } arithmetic = { path = "../arithmetic" } - -jf-primitives = { git = "https://github.com/EspressoSystems/jellyfish", rev = "ff43209" } +pcs = { path = "../pcs" } +util = { path = "../util" } [dev-dependencies] ark-ec = { version = "^0.3.0", default-features = false } @@ -35,16 +35,20 @@ harness = false [features] # default = [ "parallel", "print-trace" ] default = [ "parallel" ] +# extensive sanity checks that are useful for debugging +extensive_sanity_checks = [ + "pcs/extensive_sanity_checks", + ] parallel = [ "rayon", "arithmetic/parallel", "ark-std/parallel", "ark-ff/parallel", "ark-poly/parallel", - "jf-primitives/parallel", + "pcs/parallel", + "util/parallel" ] print-trace = [ "arithmetic/print-trace", "ark-std/print-trace", - "jf-primitives/print-trace", ] \ No newline at end of file diff --git a/poly-iop/benches/bench.rs b/poly-iop/benches/bench.rs index abee327..e7e6061 100644 --- a/poly-iop/benches/bench.rs +++ b/poly-iop/benches/bench.rs @@ -1,15 +1,14 @@ -use arithmetic::{VPAuxInfo, VirtualPolynomial}; +use arithmetic::{identity_permutation_mle, VPAuxInfo, VirtualPolynomial}; use ark_bls12_381::{Bls12_381, Fr}; use ark_poly::{DenseMultilinearExtension, MultilinearExtension}; use ark_std::test_rng; -use jf_primitives::pcs::{prelude::MultilinearKzgPCS, PolynomialCommitmentScheme}; +use pcs::{prelude::MultilinearKzgPCS, PolynomialCommitmentScheme}; use poly_iop::prelude::{ - identity_permutation_mle, PermutationCheck, PolyIOP, PolyIOPErrors, ProductCheck, SumCheck, - ZeroCheck, + PermutationCheck, PolyIOP, PolyIOPErrors, ProductCheck, SumCheck, ZeroCheck, }; use std::{marker::PhantomData, rc::Rc, time::Instant}; -type Kzg = MultilinearKzgPCS; +type KZG = MultilinearKzgPCS; fn main() -> Result<(), PolyIOPErrors> { bench_permutation_check()?; @@ -140,8 +139,8 @@ fn bench_permutation_check() -> Result<(), PolyIOPErrors> { let mut rng = test_rng(); for nv in 4..20 { - let srs = Kzg::gen_srs_for_testing(&mut rng, nv + 1)?; - let (pcs_param, _) = Kzg::trim(&srs, nv + 1, Some(nv + 1))?; + let srs = KZG::gen_srs_for_testing(&mut rng, nv + 1)?; + let (pcs_param, _) = KZG::trim(&srs, nv + 1, Some(nv + 1))?; let repetition = if nv < 10 { 100 @@ -159,10 +158,10 @@ fn bench_permutation_check() -> Result<(), PolyIOPErrors> { let proof = { let start = Instant::now(); let mut transcript = - as PermutationCheck>::init_transcript(); + as PermutationCheck>::init_transcript(); transcript.append_message(b"testing", b"initializing transcript for testing")?; - let (proof, _q_x) = as PermutationCheck>::prove( + let (proof, _q_x) = as PermutationCheck>::prove( &pcs_param, &w, &w, @@ -187,9 +186,9 @@ fn bench_permutation_check() -> Result<(), PolyIOPErrors> { let start = Instant::now(); let mut transcript = - as PermutationCheck>::init_transcript(); + as PermutationCheck>::init_transcript(); transcript.append_message(b"testing", b"initializing transcript for testing")?; - let _perm_check_sum_claim = as PermutationCheck>::verify( + let _perm_check_sum_claim = as PermutationCheck>::verify( &proof, &poly_info, &mut transcript, @@ -211,8 +210,8 @@ fn bench_prod_check() -> Result<(), PolyIOPErrors> { let mut rng = test_rng(); for nv in 4..20 { - let srs = Kzg::gen_srs_for_testing(&mut rng, nv + 1)?; - let (pcs_param, _) = Kzg::trim(&srs, nv + 1, Some(nv + 1))?; + let srs = KZG::gen_srs_for_testing(&mut rng, nv + 1)?; + let (pcs_param, _) = KZG::trim(&srs, nv + 1, Some(nv + 1))?; let repetition = if nv < 10 { 100 @@ -230,10 +229,10 @@ fn bench_prod_check() -> Result<(), PolyIOPErrors> { let proof = { let start = Instant::now(); - let mut transcript = as ProductCheck>::init_transcript(); + let mut transcript = as ProductCheck>::init_transcript(); transcript.append_message(b"testing", b"initializing transcript for testing")?; - let (proof, _prod_x) = as ProductCheck>::prove( + let (proof, _prod_x) = as ProductCheck>::prove( &pcs_param, &f, &g, @@ -256,9 +255,9 @@ fn bench_prod_check() -> Result<(), PolyIOPErrors> { }; let start = Instant::now(); - let mut transcript = as ProductCheck>::init_transcript(); + let mut transcript = as ProductCheck>::init_transcript(); transcript.append_message(b"testing", b"initializing transcript for testing")?; - let _perm_check_sum_claim = as ProductCheck>::verify( + let _perm_check_sum_claim = as ProductCheck>::verify( &proof, &poly_info, &mut transcript, diff --git a/poly-iop/readme.md b/poly-iop/readme.md index 86cd654..d3759a4 100644 --- a/poly-iop/readme.md +++ b/poly-iop/readme.md @@ -3,6 +3,7 @@ Poly IOP Implements the following protocols -- [x] sum checks -- [x] zero checks -- [x] permutation checks \ No newline at end of file +- sum checks +- zero checks +- product checks +- permutation checks \ No newline at end of file diff --git a/poly-iop/src/errors.rs b/poly-iop/src/errors.rs index 8ac4d4f..eb5cfc0 100644 --- a/poly-iop/src/errors.rs +++ b/poly-iop/src/errors.rs @@ -3,8 +3,8 @@ use arithmetic::ArithErrors; use ark_std::string::String; use displaydoc::Display; -use jf_primitives::pcs::errors::PCSError; -use transcript::TranscriptErrors; +use pcs::prelude::PCSError; +use transcript::TranscriptError; /// A `enum` specifying the possible failure modes of the PolyIOP. #[derive(Display, Debug)] @@ -24,11 +24,11 @@ pub enum PolyIOPErrors { /// An error during (de)serialization: {0} SerializationErrors(ark_serialize::SerializationError), /// Transcript Error: {0} - TranscriptErrors(TranscriptErrors), + TranscriptErrors(TranscriptError), /// Arithmetic Error: {0} ArithmeticErrors(ArithErrors), /// PCS error {0} - PCSError(PCSError), + PCSErrors(PCSError), } impl From for PolyIOPErrors { @@ -37,8 +37,8 @@ impl From for PolyIOPErrors { } } -impl From for PolyIOPErrors { - fn from(e: TranscriptErrors) -> Self { +impl From for PolyIOPErrors { + fn from(e: TranscriptError) -> Self { Self::TranscriptErrors(e) } } @@ -51,6 +51,6 @@ impl From for PolyIOPErrors { impl From for PolyIOPErrors { fn from(e: PCSError) -> Self { - Self::PCSError(e) + Self::PCSErrors(e) } } diff --git a/poly-iop/src/perm_check/mod.rs b/poly-iop/src/perm_check/mod.rs index a9a7f78..1c05479 100644 --- a/poly-iop/src/perm_check/mod.rs +++ b/poly-iop/src/perm_check/mod.rs @@ -5,7 +5,7 @@ use crate::{errors::PolyIOPErrors, prelude::ProductCheck, PolyIOP}; use ark_ec::PairingEngine; use ark_poly::DenseMultilinearExtension; use ark_std::{end_timer, start_timer}; -use jf_primitives::pcs::PolynomialCommitmentScheme; +use pcs::PolynomialCommitmentScheme; use std::rc::Rc; use transcript::IOPTranscript; @@ -153,20 +153,16 @@ where #[cfg(test)] mod test { use super::PermutationCheck; - use crate::{ - errors::PolyIOPErrors, - prelude::{identity_permutation_mle, random_permutation_mle}, - PolyIOP, - }; - use arithmetic::VPAuxInfo; + use crate::{errors::PolyIOPErrors, PolyIOP}; + use arithmetic::{evaluate_opt, identity_permutation_mle, random_permutation_mle, VPAuxInfo}; use ark_bls12_381::Bls12_381; use ark_ec::PairingEngine; use ark_poly::{DenseMultilinearExtension, MultilinearExtension}; use ark_std::test_rng; - use jf_primitives::pcs::{prelude::MultilinearKzgPCS, PolynomialCommitmentScheme}; + use pcs::{prelude::MultilinearKzgPCS, PolynomialCommitmentScheme}; use std::{marker::PhantomData, rc::Rc}; - type Kzg = MultilinearKzgPCS; + type KZG = MultilinearKzgPCS; fn test_permutation_check_helper( pcs_param: &PCS::ProverParam, @@ -206,10 +202,10 @@ mod test { )?; // check product subclaim - if prod_x - .evaluate(&perm_check_sub_claim.product_check_sub_claim.final_query.0) - .unwrap() - != perm_check_sub_claim.product_check_sub_claim.final_query.1 + if evaluate_opt( + &prod_x, + &perm_check_sub_claim.product_check_sub_claim.final_query.0, + ) != perm_check_sub_claim.product_check_sub_claim.final_query.1 { return Err(PolyIOPErrors::InvalidVerifier("wrong subclaim".to_string())); }; @@ -228,7 +224,7 @@ mod test { let w = Rc::new(DenseMultilinearExtension::rand(nv, &mut rng)); // s_perm is the identity map let s_perm = identity_permutation_mle(nv); - test_permutation_check_helper::(&pcs_param, &w, &w, &s_perm)?; + test_permutation_check_helper::(&pcs_param, &w, &w, &s_perm)?; } { @@ -238,9 +234,9 @@ mod test { let s_perm = random_permutation_mle(nv, &mut rng); if nv == 1 { - test_permutation_check_helper::(&pcs_param, &w, &w, &s_perm)?; + test_permutation_check_helper::(&pcs_param, &w, &w, &s_perm)?; } else { - assert!(test_permutation_check_helper::( + assert!(test_permutation_check_helper::( &pcs_param, &w, &w, &s_perm ) .is_err()); @@ -255,7 +251,7 @@ mod test { let s_perm = identity_permutation_mle(nv); assert!( - test_permutation_check_helper::(&pcs_param, &f, &g, &s_perm) + test_permutation_check_helper::(&pcs_param, &f, &g, &s_perm) .is_err() ); } diff --git a/poly-iop/src/perm_check/util.rs b/poly-iop/src/perm_check/util.rs index cec0f4d..2098174 100644 --- a/poly-iop/src/perm_check/util.rs +++ b/poly-iop/src/perm_check/util.rs @@ -1,9 +1,10 @@ //! This module implements useful functions for the permutation check protocol. use crate::errors::PolyIOPErrors; +use arithmetic::identity_permutation_mle; use ark_ff::PrimeField; use ark_poly::DenseMultilinearExtension; -use ark_std::{end_timer, rand::RngCore, start_timer}; +use ark_std::{end_timer, start_timer}; use std::rc::Rc; /// Returns the evaluations of two MLEs: @@ -58,30 +59,3 @@ pub(super) fn computer_num_and_denom( end_timer!(start); Ok((numerator, denominator)) } - -/// An MLE that represent an identity permutation: `f(index) \mapto index` -pub fn identity_permutation_mle( - num_vars: usize, -) -> Rc> { - let s_id_vec = (0..1u64 << num_vars).map(F::from).collect(); - Rc::new(DenseMultilinearExtension::from_evaluations_vec( - num_vars, s_id_vec, - )) -} - -/// An MLE that represent a random permutation -pub fn random_permutation_mle( - num_vars: usize, - rng: &mut R, -) -> Rc> { - let len = 1u64 << num_vars; - let mut s_id_vec: Vec = (0..len).map(F::from).collect(); - let mut s_perm_vec = vec![]; - for _ in 0..len { - let index = rng.next_u64() as usize % s_id_vec.len(); - s_perm_vec.push(s_id_vec.remove(index)); - } - Rc::new(DenseMultilinearExtension::from_evaluations_vec( - num_vars, s_perm_vec, - )) -} diff --git a/poly-iop/src/prelude.rs b/poly-iop/src/prelude.rs index 634a196..36ebba1 100644 --- a/poly-iop/src/prelude.rs +++ b/poly-iop/src/prelude.rs @@ -1,12 +1,4 @@ pub use crate::{ - errors::PolyIOPErrors, - perm_check::{ - util::{identity_permutation_mle, random_permutation_mle}, - PermutationCheck, - }, - prod_check::ProductCheck, - sum_check::SumCheck, - utils::*, - zero_check::ZeroCheck, - PolyIOP, + errors::PolyIOPErrors, perm_check::PermutationCheck, prod_check::ProductCheck, + sum_check::SumCheck, utils::*, zero_check::ZeroCheck, PolyIOP, }; diff --git a/poly-iop/src/prod_check/mod.rs b/poly-iop/src/prod_check/mod.rs index 2bd1279..73440c9 100644 --- a/poly-iop/src/prod_check/mod.rs +++ b/poly-iop/src/prod_check/mod.rs @@ -11,7 +11,7 @@ use ark_ec::PairingEngine; use ark_ff::{One, PrimeField, Zero}; use ark_poly::DenseMultilinearExtension; use ark_std::{end_timer, start_timer}; -use jf_primitives::pcs::prelude::PolynomialCommitmentScheme; +use pcs::PolynomialCommitmentScheme; use std::rc::Rc; use transcript::IOPTranscript; @@ -207,7 +207,7 @@ mod test { use ark_ec::PairingEngine; use ark_poly::{DenseMultilinearExtension, MultilinearExtension}; use ark_std::test_rng; - use jf_primitives::pcs::{prelude::MultilinearKzgPCS, PolynomialCommitmentScheme}; + use pcs::{prelude::MultilinearKzgPCS, PolynomialCommitmentScheme}; use std::{marker::PhantomData, rc::Rc}; // f and g are guaranteed to have the same product diff --git a/poly-iop/src/prod_check/util.rs b/poly-iop/src/prod_check/util.rs index cb645f7..5a67f5c 100644 --- a/poly-iop/src/prod_check/util.rs +++ b/poly-iop/src/prod_check/util.rs @@ -1,12 +1,11 @@ //! This module implements useful functions for the product check protocol. -use crate::{ - errors::PolyIOPErrors, structs::IOPProof, utils::get_index, zero_check::ZeroCheck, PolyIOP, -}; -use arithmetic::VirtualPolynomial; +use crate::{errors::PolyIOPErrors, structs::IOPProof, zero_check::ZeroCheck, PolyIOP}; +use arithmetic::{get_index, VirtualPolynomial}; use ark_ff::PrimeField; use ark_poly::DenseMultilinearExtension; use ark_std::{end_timer, start_timer}; +use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; use std::rc::Rc; use transcript::IOPTranscript; @@ -191,10 +190,12 @@ fn compute_prod_0( ) -> Result, PolyIOPErrors> { let start = start_timer!(|| "compute prod(0,x)"); - let mut prod_0x_evals = vec![]; - for (&fi, &gi) in fx.iter().zip(gx.iter()) { - prod_0x_evals.push(fi / gi); - } + let input = fx + .iter() + .zip(gx.iter()) + .map(|(&fi, &gi)| (fi, gi)) + .collect::>(); + let prod_0x_evals = input.par_iter().map(|(x, y)| *x / *y).collect::>(); end_timer!(start); Ok(prod_0x_evals) diff --git a/poly-iop/src/sum_check/prover.rs b/poly-iop/src/sum_check/prover.rs index 93d2022..d8793b1 100644 --- a/poly-iop/src/sum_check/prover.rs +++ b/poly-iop/src/sum_check/prover.rs @@ -5,14 +5,15 @@ use crate::{ errors::PolyIOPErrors, structs::{IOPProverMessage, IOPProverState}, }; -use arithmetic::VirtualPolynomial; +use arithmetic::{fix_variables, VirtualPolynomial}; use ark_ff::PrimeField; -use ark_poly::{DenseMultilinearExtension, MultilinearExtension}; +use ark_poly::DenseMultilinearExtension; use ark_std::{end_timer, start_timer, vec::Vec}; +use rayon::prelude::IntoParallelIterator; use std::rc::Rc; #[cfg(feature = "parallel")] -use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}; +use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}; impl SumCheckProver for IOPProverState { type VirtualPolynomial = VirtualPolynomial; @@ -44,8 +45,9 @@ impl SumCheckProver for IOPProverState { &mut self, challenge: &Option, ) -> Result { - let start = - start_timer!(|| format!("sum check prove {}-th round and update state", self.round)); + // let start = + // start_timer!(|| format!("sum check prove {}-th round and update state", + // self.round)); if self.round >= self.poly.aux_info.num_variables { return Err(PolyIOPErrors::InvalidProver( @@ -53,7 +55,7 @@ impl SumCheckProver for IOPProverState { )); } - let fix_argument = start_timer!(|| "fix argument"); + // let fix_argument = start_timer!(|| "fix argument"); // Step 1: // fix argument and evaluate f(x) over x_m = r; where r is the challenge @@ -85,18 +87,18 @@ impl SumCheckProver for IOPProverState { #[cfg(feature = "parallel")] flattened_ml_extensions .par_iter_mut() - .for_each(|mle| *mle = mle.fix_variables(&[r])); + .for_each(|mle| *mle = fix_variables(mle, &[r])); #[cfg(not(feature = "parallel"))] flattened_ml_extensions .iter_mut() - .for_each(|mle| *mle = mle.fix_variables(&[r])); + .for_each(|mle| *mle = fix_variables(mle, &[r])); } else if self.round > 0 { return Err(PolyIOPErrors::InvalidProver( "verifier message is empty".to_string(), )); } - end_timer!(fix_argument); + // end_timer!(fix_argument); self.round += 1; @@ -104,29 +106,44 @@ impl SumCheckProver for IOPProverState { let mut products_sum = Vec::with_capacity(self.poly.aux_info.max_degree + 1); products_sum.resize(self.poly.aux_info.max_degree + 1, F::zero()); - let compute_sum = start_timer!(|| "compute sum"); + // let compute_sum = start_timer!(|| "compute sum"); + // Step 2: generate sum for the partial evaluated polynomial: // f(r_1, ... r_m,, x_{m+1}... x_n) #[cfg(feature = "parallel")] - products_sum.par_iter_mut().enumerate().for_each(|(t, e)| { - for b in 0..1 << (self.poly.aux_info.num_variables - self.round) { - // evaluate P_round(t) - for (coefficient, products) in products_list.iter() { - let num_mles = products.len(); - let mut product = *coefficient; - for &f in products.iter().take(num_mles) { - let table = &flattened_ml_extensions[f]; // f's range is checked in init - product *= table[b << 1] * (F::one() - F::from(t as u64)) - + table[(b << 1) + 1] * F::from(t as u64); - } - *e += product; - } + for (t, e) in products_sum.iter_mut().enumerate() { + let t = F::from(t as u64); + let one_minus_t = F::one() - t; + let products = (0..1 << (self.poly.aux_info.num_variables - self.round)) + .into_par_iter() + .map(|b| { + // evaluate P_round(t) + let mut tmp = F::zero(); + products_list.iter().for_each(|(coefficient, products)| { + let num_mles = products.len(); + let mut product = *coefficient; + for &f in products.iter().take(num_mles) { + let table = &flattened_ml_extensions[f]; // f's range is checked in init + product *= table[b << 1] * one_minus_t + table[(b << 1) + 1] * t; + } + tmp += product; + }); + + tmp + }) + .collect::>(); + + for i in products.iter() { + *e += i } - }); + } #[cfg(not(feature = "parallel"))] products_sum.iter_mut().enumerate().for_each(|(t, e)| { + let t = F::from(t as u64); + let one_minus_t = F::one() - t; + for b in 0..1 << (self.poly.aux_info.num_variables - self.round) { // evaluate P_round(t) for (coefficient, products) in products_list.iter() { @@ -134,8 +151,7 @@ impl SumCheckProver for IOPProverState { let mut product = *coefficient; for &f in products.iter().take(num_mles) { let table = &flattened_ml_extensions[f]; // f's range is checked in init - product *= table[b << 1] * (F::one() - F::from(t as u64)) - + table[(b << 1) + 1] * F::from(t as u64); + product *= table[b << 1] * one_minus_t + table[(b << 1) + 1] * t; } *e += product; } @@ -148,8 +164,8 @@ impl SumCheckProver for IOPProverState { .map(|x| Rc::new(x.clone())) .collect(); - end_timer!(compute_sum); - end_timer!(start); + // end_timer!(compute_sum); + // end_timer!(start); Ok(IOPProverMessage { evaluations: products_sum, }) diff --git a/poly-iop/src/utils.rs b/poly-iop/src/utils.rs index 8d43b8b..a5851ae 100644 --- a/poly-iop/src/utils.rs +++ b/poly-iop/src/utils.rs @@ -11,53 +11,11 @@ macro_rules! to_bytes { }}; } -/// Decompose an integer into a binary vector in little endian. -#[allow(dead_code)] -pub fn bit_decompose(input: u64, num_var: usize) -> Vec { - let mut res = Vec::with_capacity(num_var); - let mut i = input; - for _ in 0..num_var { - res.push(i & 1 == 1); - i >>= 1; - } - res -} - -/// Project a little endian binary vector into an integer. -#[allow(dead_code)] -pub(crate) fn project(input: &[bool]) -> u64 { - let mut res = 0; - for &e in input.iter().rev() { - res <<= 1; - res += e as u64; - } - res -} - -// Input index -// - `i := (i_0, ...i_{n-1})`, -// - `num_vars := n` -// return three elements: -// - `x0 := (i_1, ..., i_{n-1}, 0)` -// - `x1 := (i_1, ..., i_{n-1}, 1)` -// - `sign := i_0` -#[inline] -pub(crate) fn get_index(i: usize, num_vars: usize) -> (usize, usize, bool) { - let bit_sequence = bit_decompose(i as u64, num_vars); - - // the last bit comes first here because of LE encoding - let x0 = project(&[[false].as_ref(), bit_sequence[..num_vars - 1].as_ref()].concat()) as usize; - let x1 = project(&[[true].as_ref(), bit_sequence[..num_vars - 1].as_ref()].concat()) as usize; - - (x0, x1, bit_sequence[num_vars - 1]) -} - #[cfg(test)] mod test { - use super::{bit_decompose, get_index, project}; use ark_bls12_381::Fr; use ark_serialize::CanonicalSerialize; - use ark_std::{rand::RngCore, test_rng, One}; + use ark_std::One; #[test] fn test_to_bytes() { @@ -67,35 +25,4 @@ mod test { f1.serialize(&mut bytes).unwrap(); assert_eq!(bytes, to_bytes!(&f1).unwrap()); } - - #[test] - fn test_decomposition() { - let mut rng = test_rng(); - for _ in 0..100 { - let t = rng.next_u64(); - let b = bit_decompose(t, 64); - let r = project(&b); - assert_eq!(t, r) - } - } - - #[test] - fn test_get_index() { - let a = 0b1010; - let (x0, x1, sign) = get_index(a, 4); - assert_eq!(x0, 0b0100); - assert_eq!(x1, 0b0101); - assert!(sign); - - let (x0, x1, sign) = get_index(a, 5); - assert_eq!(x0, 0b10100); - assert_eq!(x1, 0b10101); - assert!(!sign); - - let a = 0b1111; - let (x0, x1, sign) = get_index(a, 4); - assert_eq!(x0, 0b1110); - assert_eq!(x1, 0b1111); - assert!(sign); - } } diff --git a/transcript/src/errors.rs b/transcript/src/errors.rs index 0886575..6a29659 100644 --- a/transcript/src/errors.rs +++ b/transcript/src/errors.rs @@ -5,14 +5,14 @@ use displaydoc::Display; /// A `enum` specifying the possible failure modes of the Transcript. #[derive(Display, Debug)] -pub enum TranscriptErrors { +pub enum TranscriptError { /// Invalid Transcript: {0} InvalidTranscript(String), /// An error during (de)serialization: {0} SerializationError(ark_serialize::SerializationError), } -impl From for TranscriptErrors { +impl From for TranscriptError { fn from(e: ark_serialize::SerializationError) -> Self { Self::SerializationError(e) } diff --git a/transcript/src/lib.rs b/transcript/src/lib.rs index 4764b88..95c6dfc 100644 --- a/transcript/src/lib.rs +++ b/transcript/src/lib.rs @@ -4,7 +4,7 @@ //! TODO(ZZ): decide which APIs need to be public. mod errors; -pub use errors::TranscriptErrors; +pub use errors::TranscriptError; use ark_ff::PrimeField; use ark_serialize::CanonicalSerialize; @@ -44,7 +44,7 @@ impl IOPTranscript { &mut self, label: &'static [u8], msg: &[u8], - ) -> Result<(), TranscriptErrors> { + ) -> Result<(), TranscriptError> { self.transcript.append_message(label, msg); self.is_empty = false; Ok(()) @@ -55,7 +55,7 @@ impl IOPTranscript { &mut self, label: &'static [u8], field_elem: &F, - ) -> Result<(), TranscriptErrors> { + ) -> Result<(), TranscriptError> { self.append_message(label, &to_bytes!(field_elem)?) } @@ -64,7 +64,7 @@ impl IOPTranscript { &mut self, label: &'static [u8], group_elem: &S, - ) -> Result<(), TranscriptErrors> { + ) -> Result<(), TranscriptError> { self.append_message(label, &to_bytes!(group_elem)?) } @@ -73,13 +73,10 @@ impl IOPTranscript { // // The output field element is statistical uniform as long // as the field has a size less than 2^384. - pub fn get_and_append_challenge( - &mut self, - label: &'static [u8], - ) -> Result { + pub fn get_and_append_challenge(&mut self, label: &'static [u8]) -> Result { // we need to reject when transcript is empty if self.is_empty { - return Err(TranscriptErrors::InvalidTranscript( + return Err(TranscriptError::InvalidTranscript( "transcript is empty".to_string(), )); } @@ -100,10 +97,10 @@ impl IOPTranscript { &mut self, label: &'static [u8], len: usize, - ) -> Result, TranscriptErrors> { + ) -> Result, TranscriptError> { // we need to reject when transcript is empty if self.is_empty { - return Err(TranscriptErrors::InvalidTranscript( + return Err(TranscriptError::InvalidTranscript( "transcript is empty".to_string(), )); } diff --git a/util/Cargo.toml b/util/Cargo.toml new file mode 100644 index 0000000..45ad040 --- /dev/null +++ b/util/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "util" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +rayon = { version = "1.5.0", optional = true } + +[features] +default = [] +parallel = [ "rayon" ] \ No newline at end of file diff --git a/util/src/lib.rs b/util/src/lib.rs new file mode 100644 index 0000000..c99c8f9 --- /dev/null +++ b/util/src/lib.rs @@ -0,0 +1,29 @@ +// Copyright (c) 2022 Espresso Systems (espressosys.com) +// This file is part of the Jellyfish library. +// You should have received a copy of the MIT License +// along with the Jellyfish library. If not, see . + +//! Utilities for parallel code. + +/// this function helps with slice iterator creation that optionally use +/// `par_iter()` when feature flag `parallel` is on. +/// +/// # Usage +/// let v = [1, 2, 3, 4, 5]; +/// let sum = parallelizable_slice_iter(&v).sum(); +/// +/// // the above code is a shorthand for (thus equivalent to) +/// #[cfg(feature = "parallel")] +/// let sum = v.par_iter().sum(); +/// #[cfg(not(feature = "parallel"))] +/// let sum = v.iter().sum(); +#[cfg(feature = "parallel")] +pub fn parallelizable_slice_iter(data: &[T]) -> rayon::slice::Iter { + use rayon::iter::IntoParallelIterator; + data.into_par_iter() +} + +#[cfg(not(feature = "parallel"))] +pub fn parallelizable_slice_iter(data: &[T]) -> core::slice::Iter { + data.iter() +}