enabling batch opening and mock tests (#80)

- add mock circuits
- add vanilla and jellyfish plonk gates
- performance tuning
This commit is contained in:
zhenfei
2022-09-27 14:51:30 -04:00
committed by GitHub
parent 3160ef17f2
commit baaa06b07b
51 changed files with 5637 additions and 1388 deletions

View File

@@ -3,8 +3,8 @@
use arithmetic::ArithErrors;
use ark_std::string::String;
use displaydoc::Display;
use jf_primitives::pcs::errors::PCSError;
use transcript::TranscriptErrors;
use pcs::prelude::PCSError;
use transcript::TranscriptError;
/// A `enum` specifying the possible failure modes of the PolyIOP.
#[derive(Display, Debug)]
@@ -24,11 +24,11 @@ pub enum PolyIOPErrors {
/// An error during (de)serialization: {0}
SerializationErrors(ark_serialize::SerializationError),
/// Transcript Error: {0}
TranscriptErrors(TranscriptErrors),
TranscriptErrors(TranscriptError),
/// Arithmetic Error: {0}
ArithmeticErrors(ArithErrors),
/// PCS error {0}
PCSError(PCSError),
PCSErrors(PCSError),
}
impl From<ark_serialize::SerializationError> for PolyIOPErrors {
@@ -37,8 +37,8 @@ impl From<ark_serialize::SerializationError> for PolyIOPErrors {
}
}
impl From<TranscriptErrors> for PolyIOPErrors {
fn from(e: TranscriptErrors) -> Self {
impl From<TranscriptError> for PolyIOPErrors {
fn from(e: TranscriptError) -> Self {
Self::TranscriptErrors(e)
}
}
@@ -51,6 +51,6 @@ impl From<ArithErrors> for PolyIOPErrors {
impl From<PCSError> for PolyIOPErrors {
fn from(e: PCSError) -> Self {
Self::PCSError(e)
Self::PCSErrors(e)
}
}

View File

@@ -5,7 +5,7 @@ use crate::{errors::PolyIOPErrors, prelude::ProductCheck, PolyIOP};
use ark_ec::PairingEngine;
use ark_poly::DenseMultilinearExtension;
use ark_std::{end_timer, start_timer};
use jf_primitives::pcs::PolynomialCommitmentScheme;
use pcs::PolynomialCommitmentScheme;
use std::rc::Rc;
use transcript::IOPTranscript;
@@ -153,20 +153,16 @@ where
#[cfg(test)]
mod test {
use super::PermutationCheck;
use crate::{
errors::PolyIOPErrors,
prelude::{identity_permutation_mle, random_permutation_mle},
PolyIOP,
};
use arithmetic::VPAuxInfo;
use crate::{errors::PolyIOPErrors, PolyIOP};
use arithmetic::{evaluate_opt, identity_permutation_mle, random_permutation_mle, VPAuxInfo};
use ark_bls12_381::Bls12_381;
use ark_ec::PairingEngine;
use ark_poly::{DenseMultilinearExtension, MultilinearExtension};
use ark_std::test_rng;
use jf_primitives::pcs::{prelude::MultilinearKzgPCS, PolynomialCommitmentScheme};
use pcs::{prelude::MultilinearKzgPCS, PolynomialCommitmentScheme};
use std::{marker::PhantomData, rc::Rc};
type Kzg = MultilinearKzgPCS<Bls12_381>;
type KZG = MultilinearKzgPCS<Bls12_381>;
fn test_permutation_check_helper<E, PCS>(
pcs_param: &PCS::ProverParam,
@@ -206,10 +202,10 @@ mod test {
)?;
// check product subclaim
if prod_x
.evaluate(&perm_check_sub_claim.product_check_sub_claim.final_query.0)
.unwrap()
!= perm_check_sub_claim.product_check_sub_claim.final_query.1
if evaluate_opt(
&prod_x,
&perm_check_sub_claim.product_check_sub_claim.final_query.0,
) != perm_check_sub_claim.product_check_sub_claim.final_query.1
{
return Err(PolyIOPErrors::InvalidVerifier("wrong subclaim".to_string()));
};
@@ -228,7 +224,7 @@ mod test {
let w = Rc::new(DenseMultilinearExtension::rand(nv, &mut rng));
// s_perm is the identity map
let s_perm = identity_permutation_mle(nv);
test_permutation_check_helper::<Bls12_381, Kzg>(&pcs_param, &w, &w, &s_perm)?;
test_permutation_check_helper::<Bls12_381, KZG>(&pcs_param, &w, &w, &s_perm)?;
}
{
@@ -238,9 +234,9 @@ mod test {
let s_perm = random_permutation_mle(nv, &mut rng);
if nv == 1 {
test_permutation_check_helper::<Bls12_381, Kzg>(&pcs_param, &w, &w, &s_perm)?;
test_permutation_check_helper::<Bls12_381, KZG>(&pcs_param, &w, &w, &s_perm)?;
} else {
assert!(test_permutation_check_helper::<Bls12_381, Kzg>(
assert!(test_permutation_check_helper::<Bls12_381, KZG>(
&pcs_param, &w, &w, &s_perm
)
.is_err());
@@ -255,7 +251,7 @@ mod test {
let s_perm = identity_permutation_mle(nv);
assert!(
test_permutation_check_helper::<Bls12_381, Kzg>(&pcs_param, &f, &g, &s_perm)
test_permutation_check_helper::<Bls12_381, KZG>(&pcs_param, &f, &g, &s_perm)
.is_err()
);
}

View File

@@ -1,9 +1,10 @@
//! This module implements useful functions for the permutation check protocol.
use crate::errors::PolyIOPErrors;
use arithmetic::identity_permutation_mle;
use ark_ff::PrimeField;
use ark_poly::DenseMultilinearExtension;
use ark_std::{end_timer, rand::RngCore, start_timer};
use ark_std::{end_timer, start_timer};
use std::rc::Rc;
/// Returns the evaluations of two MLEs:
@@ -58,30 +59,3 @@ pub(super) fn computer_num_and_denom<F: PrimeField>(
end_timer!(start);
Ok((numerator, denominator))
}
/// An MLE that represent an identity permutation: `f(index) \mapto index`
pub fn identity_permutation_mle<F: PrimeField>(
num_vars: usize,
) -> Rc<DenseMultilinearExtension<F>> {
let s_id_vec = (0..1u64 << num_vars).map(F::from).collect();
Rc::new(DenseMultilinearExtension::from_evaluations_vec(
num_vars, s_id_vec,
))
}
/// An MLE that represent a random permutation
pub fn random_permutation_mle<F: PrimeField, R: RngCore>(
num_vars: usize,
rng: &mut R,
) -> Rc<DenseMultilinearExtension<F>> {
let len = 1u64 << num_vars;
let mut s_id_vec: Vec<F> = (0..len).map(F::from).collect();
let mut s_perm_vec = vec![];
for _ in 0..len {
let index = rng.next_u64() as usize % s_id_vec.len();
s_perm_vec.push(s_id_vec.remove(index));
}
Rc::new(DenseMultilinearExtension::from_evaluations_vec(
num_vars, s_perm_vec,
))
}

View File

@@ -1,12 +1,4 @@
pub use crate::{
errors::PolyIOPErrors,
perm_check::{
util::{identity_permutation_mle, random_permutation_mle},
PermutationCheck,
},
prod_check::ProductCheck,
sum_check::SumCheck,
utils::*,
zero_check::ZeroCheck,
PolyIOP,
errors::PolyIOPErrors, perm_check::PermutationCheck, prod_check::ProductCheck,
sum_check::SumCheck, utils::*, zero_check::ZeroCheck, PolyIOP,
};

View File

@@ -11,7 +11,7 @@ use ark_ec::PairingEngine;
use ark_ff::{One, PrimeField, Zero};
use ark_poly::DenseMultilinearExtension;
use ark_std::{end_timer, start_timer};
use jf_primitives::pcs::prelude::PolynomialCommitmentScheme;
use pcs::PolynomialCommitmentScheme;
use std::rc::Rc;
use transcript::IOPTranscript;
@@ -207,7 +207,7 @@ mod test {
use ark_ec::PairingEngine;
use ark_poly::{DenseMultilinearExtension, MultilinearExtension};
use ark_std::test_rng;
use jf_primitives::pcs::{prelude::MultilinearKzgPCS, PolynomialCommitmentScheme};
use pcs::{prelude::MultilinearKzgPCS, PolynomialCommitmentScheme};
use std::{marker::PhantomData, rc::Rc};
// f and g are guaranteed to have the same product

View File

@@ -1,12 +1,11 @@
//! This module implements useful functions for the product check protocol.
use crate::{
errors::PolyIOPErrors, structs::IOPProof, utils::get_index, zero_check::ZeroCheck, PolyIOP,
};
use arithmetic::VirtualPolynomial;
use crate::{errors::PolyIOPErrors, structs::IOPProof, zero_check::ZeroCheck, PolyIOP};
use arithmetic::{get_index, VirtualPolynomial};
use ark_ff::PrimeField;
use ark_poly::DenseMultilinearExtension;
use ark_std::{end_timer, start_timer};
use rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
use std::rc::Rc;
use transcript::IOPTranscript;
@@ -191,10 +190,12 @@ fn compute_prod_0<F: PrimeField>(
) -> Result<Vec<F>, PolyIOPErrors> {
let start = start_timer!(|| "compute prod(0,x)");
let mut prod_0x_evals = vec![];
for (&fi, &gi) in fx.iter().zip(gx.iter()) {
prod_0x_evals.push(fi / gi);
}
let input = fx
.iter()
.zip(gx.iter())
.map(|(&fi, &gi)| (fi, gi))
.collect::<Vec<_>>();
let prod_0x_evals = input.par_iter().map(|(x, y)| *x / *y).collect::<Vec<_>>();
end_timer!(start);
Ok(prod_0x_evals)

View File

@@ -5,14 +5,15 @@ use crate::{
errors::PolyIOPErrors,
structs::{IOPProverMessage, IOPProverState},
};
use arithmetic::VirtualPolynomial;
use arithmetic::{fix_variables, VirtualPolynomial};
use ark_ff::PrimeField;
use ark_poly::{DenseMultilinearExtension, MultilinearExtension};
use ark_poly::DenseMultilinearExtension;
use ark_std::{end_timer, start_timer, vec::Vec};
use rayon::prelude::IntoParallelIterator;
use std::rc::Rc;
#[cfg(feature = "parallel")]
use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator};
impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
type VirtualPolynomial = VirtualPolynomial<F>;
@@ -44,8 +45,9 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
&mut self,
challenge: &Option<F>,
) -> Result<Self::ProverMessage, PolyIOPErrors> {
let start =
start_timer!(|| format!("sum check prove {}-th round and update state", self.round));
// let start =
// start_timer!(|| format!("sum check prove {}-th round and update state",
// self.round));
if self.round >= self.poly.aux_info.num_variables {
return Err(PolyIOPErrors::InvalidProver(
@@ -53,7 +55,7 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
));
}
let fix_argument = start_timer!(|| "fix argument");
// let fix_argument = start_timer!(|| "fix argument");
// Step 1:
// fix argument and evaluate f(x) over x_m = r; where r is the challenge
@@ -85,18 +87,18 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
#[cfg(feature = "parallel")]
flattened_ml_extensions
.par_iter_mut()
.for_each(|mle| *mle = mle.fix_variables(&[r]));
.for_each(|mle| *mle = fix_variables(mle, &[r]));
#[cfg(not(feature = "parallel"))]
flattened_ml_extensions
.iter_mut()
.for_each(|mle| *mle = mle.fix_variables(&[r]));
.for_each(|mle| *mle = fix_variables(mle, &[r]));
} else if self.round > 0 {
return Err(PolyIOPErrors::InvalidProver(
"verifier message is empty".to_string(),
));
}
end_timer!(fix_argument);
// end_timer!(fix_argument);
self.round += 1;
@@ -104,29 +106,44 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
let mut products_sum = Vec::with_capacity(self.poly.aux_info.max_degree + 1);
products_sum.resize(self.poly.aux_info.max_degree + 1, F::zero());
let compute_sum = start_timer!(|| "compute sum");
// let compute_sum = start_timer!(|| "compute sum");
// Step 2: generate sum for the partial evaluated polynomial:
// f(r_1, ... r_m,, x_{m+1}... x_n)
#[cfg(feature = "parallel")]
products_sum.par_iter_mut().enumerate().for_each(|(t, e)| {
for b in 0..1 << (self.poly.aux_info.num_variables - self.round) {
// evaluate P_round(t)
for (coefficient, products) in products_list.iter() {
let num_mles = products.len();
let mut product = *coefficient;
for &f in products.iter().take(num_mles) {
let table = &flattened_ml_extensions[f]; // f's range is checked in init
product *= table[b << 1] * (F::one() - F::from(t as u64))
+ table[(b << 1) + 1] * F::from(t as u64);
}
*e += product;
}
for (t, e) in products_sum.iter_mut().enumerate() {
let t = F::from(t as u64);
let one_minus_t = F::one() - t;
let products = (0..1 << (self.poly.aux_info.num_variables - self.round))
.into_par_iter()
.map(|b| {
// evaluate P_round(t)
let mut tmp = F::zero();
products_list.iter().for_each(|(coefficient, products)| {
let num_mles = products.len();
let mut product = *coefficient;
for &f in products.iter().take(num_mles) {
let table = &flattened_ml_extensions[f]; // f's range is checked in init
product *= table[b << 1] * one_minus_t + table[(b << 1) + 1] * t;
}
tmp += product;
});
tmp
})
.collect::<Vec<F>>();
for i in products.iter() {
*e += i
}
});
}
#[cfg(not(feature = "parallel"))]
products_sum.iter_mut().enumerate().for_each(|(t, e)| {
let t = F::from(t as u64);
let one_minus_t = F::one() - t;
for b in 0..1 << (self.poly.aux_info.num_variables - self.round) {
// evaluate P_round(t)
for (coefficient, products) in products_list.iter() {
@@ -134,8 +151,7 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
let mut product = *coefficient;
for &f in products.iter().take(num_mles) {
let table = &flattened_ml_extensions[f]; // f's range is checked in init
product *= table[b << 1] * (F::one() - F::from(t as u64))
+ table[(b << 1) + 1] * F::from(t as u64);
product *= table[b << 1] * one_minus_t + table[(b << 1) + 1] * t;
}
*e += product;
}
@@ -148,8 +164,8 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
.map(|x| Rc::new(x.clone()))
.collect();
end_timer!(compute_sum);
end_timer!(start);
// end_timer!(compute_sum);
// end_timer!(start);
Ok(IOPProverMessage {
evaluations: products_sum,
})

View File

@@ -11,53 +11,11 @@ macro_rules! to_bytes {
}};
}
/// Decompose an integer into a binary vector in little endian.
#[allow(dead_code)]
pub fn bit_decompose(input: u64, num_var: usize) -> Vec<bool> {
let mut res = Vec::with_capacity(num_var);
let mut i = input;
for _ in 0..num_var {
res.push(i & 1 == 1);
i >>= 1;
}
res
}
/// Project a little endian binary vector into an integer.
#[allow(dead_code)]
pub(crate) fn project(input: &[bool]) -> u64 {
let mut res = 0;
for &e in input.iter().rev() {
res <<= 1;
res += e as u64;
}
res
}
// Input index
// - `i := (i_0, ...i_{n-1})`,
// - `num_vars := n`
// return three elements:
// - `x0 := (i_1, ..., i_{n-1}, 0)`
// - `x1 := (i_1, ..., i_{n-1}, 1)`
// - `sign := i_0`
#[inline]
pub(crate) fn get_index(i: usize, num_vars: usize) -> (usize, usize, bool) {
let bit_sequence = bit_decompose(i as u64, num_vars);
// the last bit comes first here because of LE encoding
let x0 = project(&[[false].as_ref(), bit_sequence[..num_vars - 1].as_ref()].concat()) as usize;
let x1 = project(&[[true].as_ref(), bit_sequence[..num_vars - 1].as_ref()].concat()) as usize;
(x0, x1, bit_sequence[num_vars - 1])
}
#[cfg(test)]
mod test {
use super::{bit_decompose, get_index, project};
use ark_bls12_381::Fr;
use ark_serialize::CanonicalSerialize;
use ark_std::{rand::RngCore, test_rng, One};
use ark_std::One;
#[test]
fn test_to_bytes() {
@@ -67,35 +25,4 @@ mod test {
f1.serialize(&mut bytes).unwrap();
assert_eq!(bytes, to_bytes!(&f1).unwrap());
}
#[test]
fn test_decomposition() {
let mut rng = test_rng();
for _ in 0..100 {
let t = rng.next_u64();
let b = bit_decompose(t, 64);
let r = project(&b);
assert_eq!(t, r)
}
}
#[test]
fn test_get_index() {
let a = 0b1010;
let (x0, x1, sign) = get_index(a, 4);
assert_eq!(x0, 0b0100);
assert_eq!(x1, 0b0101);
assert!(sign);
let (x0, x1, sign) = get_index(a, 5);
assert_eq!(x0, 0b10100);
assert_eq!(x1, 0b10101);
assert!(!sign);
let a = 0b1111;
let (x0, x1, sign) = get_index(a, 4);
assert_eq!(x0, 0b1110);
assert_eq!(x1, 0b1111);
assert!(sign);
}
}