polish IOP code base (#24)

This commit is contained in:
zhenfei
2022-05-20 12:30:32 -04:00
committed by GitHub
parent b9527f8e37
commit 97a89d7ecc
9 changed files with 445 additions and 361 deletions

View File

@@ -1,12 +1,10 @@
//! This module implements the sum check protocol.
//! Currently this is a simple wrapper of the sumcheck protocol
//! from Arkworks.
use crate::{
errors::PolyIOPErrors,
structs::{DomainInfo, IOPProof, IOPProverState, IOPVerifierState, SubClaim},
structs::{IOPProof, IOPProverState, IOPVerifierState, SubClaim},
transcript::IOPTranscript,
virtual_poly::VirtualPolynomial,
virtual_poly::{VPAuxInfo, VirtualPolynomial},
PolyIOP,
};
use ark_ff::PrimeField;
@@ -18,12 +16,12 @@ mod verifier;
/// Trait for doing sum check protocols.
pub trait SumCheck<F: PrimeField> {
type Proof;
type PolyList;
type DomainInfo;
type VirtualPolynomial;
type VPAuxInfo;
type SubClaim;
type Transcript;
/// extract sum from the proof
/// Extract sum from the proof
fn extract_sum(proof: &Self::Proof) -> F;
/// Initialize the system with a transcript
@@ -38,7 +36,7 @@ pub trait SumCheck<F: PrimeField> {
///
/// The polynomial is represented in the form of a VirtualPolynomial.
fn prove(
poly: &Self::PolyList,
poly: &Self::VirtualPolynomial,
transcript: &mut Self::Transcript,
) -> Result<Self::Proof, PolyIOPErrors>;
@@ -46,7 +44,7 @@ pub trait SumCheck<F: PrimeField> {
fn verify(
sum: F,
proof: &Self::Proof,
domain_info: &Self::DomainInfo,
aux_info: &Self::VPAuxInfo,
transcript: &mut Self::Transcript,
) -> Result<Self::SubClaim, PolyIOPErrors>;
}
@@ -56,17 +54,15 @@ pub trait SumCheckProver<F: PrimeField>
where
Self: Sized,
{
type PolyList;
type VirtualPolynomial;
type ProverMessage;
/// Initialize the prover to argue for the sum of polynomial over
/// {0,1}^`num_vars`
///
/// The polynomial is represented in the form of a VirtualPolynomial.
fn prover_init(polynomial: &Self::PolyList) -> Result<Self, PolyIOPErrors>;
/// Initialize the prover state to argue for the sum of the input polynomial
/// over {0,1}^`num_vars`.
fn prover_init(polynomial: &Self::VirtualPolynomial) -> Result<Self, PolyIOPErrors>;
/// receive message from verifier, generate prover message, and proceed to
/// next round
/// Receive message from verifier, generate prover message, and proceed to
/// next round.
///
/// Main algorithm used is from section 3.2 of [XZZPS19](https://eprint.iacr.org/2019/317.pdf#subsection.3.2).
fn prove_round_and_update_state(
@@ -77,31 +73,33 @@ where
/// Trait for sum check protocol verifier side APIs.
pub trait SumCheckVerifier<F: PrimeField> {
type DomainInfo;
type VPAuxInfo;
type ProverMessage;
type Challenge;
type Transcript;
type SubClaim;
/// initialize the verifier
fn verifier_init(index_info: &Self::DomainInfo) -> Self;
/// Initialize the verifier's state.
fn verifier_init(index_info: &Self::VPAuxInfo) -> Self;
/// Run verifier at current round, given prover message
/// Run verifier for the current round, given a prover message.
///
/// Normally, this function should perform actual verification. Instead,
/// `verify_round` only samples and stores randomness and perform
/// verifications altogether in `check_and_generate_subclaim` at
/// the last step.
/// Note that `verify_round_and_update_state` only samples and stores
/// challenges; and update the verifier's state accordingly. The actual
/// verifications are deferred (in batch) to `check_and_generate_subclaim`
/// at the last step.
fn verify_round_and_update_state(
&mut self,
prover_msg: &Self::ProverMessage,
transcript: &mut Self::Transcript,
) -> Result<Self::Challenge, PolyIOPErrors>;
/// verify the sumcheck phase, and generate the subclaim
/// This function verifies the deferred checks in the interactive version of
/// the protocol; and generate the subclaim. Returns an error if the
/// proof failed to verify.
///
/// If the asserted sum is correct, then the multilinear polynomial
/// evaluated at `subclaim.point` is `subclaim.expected_evaluation`.
/// evaluated at `subclaim.point` will be `subclaim.expected_evaluation`.
/// Otherwise, it is highly unlikely that those two will be equal.
/// Larger field size guarantees smaller soundness error.
fn check_and_generate_subclaim(
@@ -112,15 +110,12 @@ pub trait SumCheckVerifier<F: PrimeField> {
impl<F: PrimeField> SumCheck<F> for PolyIOP<F> {
type Proof = IOPProof<F>;
type PolyList = VirtualPolynomial<F>;
type DomainInfo = DomainInfo<F>;
type VirtualPolynomial = VirtualPolynomial<F>;
type VPAuxInfo = VPAuxInfo<F>;
type SubClaim = SubClaim<F>;
type Transcript = IOPTranscript<F>;
/// Extract sum from the proof
fn extract_sum(proof: &Self::Proof) -> F {
let start = start_timer!(|| "extract sum");
let res = proof.proofs[0].evaluations[0] + proof.proofs[0].evaluations[1];
@@ -145,17 +140,17 @@ impl<F: PrimeField> SumCheck<F> for PolyIOP<F> {
///
/// The polynomial is represented in the form of a VirtualPolynomial.
fn prove(
poly: &Self::PolyList,
poly: &Self::VirtualPolynomial,
transcript: &mut Self::Transcript,
) -> Result<Self::Proof, PolyIOPErrors> {
let start = start_timer!(|| "sum check prove");
transcript.append_domain_info(&poly.domain_info)?;
transcript.append_aux_info(&poly.aux_info)?;
let mut prover_state = IOPProverState::prover_init(poly)?;
let mut challenge = None;
let mut prover_msgs = Vec::with_capacity(poly.domain_info.num_variables);
for _ in 0..poly.domain_info.num_variables {
let mut prover_msgs = Vec::with_capacity(poly.aux_info.num_variables);
for _ in 0..poly.aux_info.num_variables {
let prover_msg =
IOPProverState::prove_round_and_update_state(&mut prover_state, &challenge)?;
transcript.append_prover_message(&prover_msg)?;
@@ -169,18 +164,18 @@ impl<F: PrimeField> SumCheck<F> for PolyIOP<F> {
})
}
/// verify the claimed sum using the proof
/// Verify the claimed sum using the proof
fn verify(
claimed_sum: F,
proof: &Self::Proof,
domain_info: &Self::DomainInfo,
aux_info: &Self::VPAuxInfo,
transcript: &mut Self::Transcript,
) -> Result<Self::SubClaim, PolyIOPErrors> {
let start = start_timer!(|| "sum check verify");
transcript.append_domain_info(domain_info)?;
let mut verifier_state = IOPVerifierState::verifier_init(domain_info);
for i in 0..domain_info.num_variables {
transcript.append_aux_info(aux_info)?;
let mut verifier_state = IOPVerifierState::verifier_init(aux_info);
for i in 0..aux_info.num_variables {
let prover_msg = proof.proofs.get(i).expect("proof is incomplete");
transcript.append_prover_message(prover_msg)?;
IOPVerifierState::verify_round_and_update_state(
@@ -218,7 +213,7 @@ mod test {
let (poly, asserted_sum) =
VirtualPolynomial::rand(nv, num_multiplicands_range, num_products, &mut rng)?;
let proof = <PolyIOP<Fr> as SumCheck<Fr>>::prove(&poly, &mut transcript)?;
let poly_info = poly.domain_info.clone();
let poly_info = poly.aux_info.clone();
let mut transcript = <PolyIOP<Fr> as SumCheck<Fr>>::init_transcript();
let subclaim = <PolyIOP<Fr> as SumCheck<Fr>>::verify(
asserted_sum,
@@ -241,7 +236,7 @@ mod test {
let mut rng = test_rng();
let (poly, asserted_sum) =
VirtualPolynomial::<Fr>::rand(nv, num_multiplicands_range, num_products, &mut rng)?;
let poly_info = poly.domain_info.clone();
let poly_info = poly.aux_info.clone();
let mut prover_state = IOPProverState::prover_init(&poly)?;
let mut verifier_state = IOPVerifierState::verifier_init(&poly_info);
let mut challenge = None;
@@ -249,7 +244,7 @@ mod test {
transcript
.append_message(b"testing", b"initializing transcript for testing")
.unwrap();
for _ in 0..poly.domain_info.num_variables {
for _ in 0..poly.aux_info.num_variables {
let prover_message =
IOPProverState::prove_round_and_update_state(&mut prover_state, &challenge)
.unwrap();
@@ -362,7 +357,7 @@ mod test {
drop(prover);
let mut transcript = <PolyIOP<Fr> as SumCheck<Fr>>::init_transcript();
let poly_info = poly.domain_info.clone();
let poly_info = poly.aux_info.clone();
let proof = <PolyIOP<Fr> as SumCheck<Fr>>::prove(&poly, &mut transcript)?;
let asserted_sum = <PolyIOP<Fr> as SumCheck<Fr>>::extract_sum(&proof);

View File

@@ -1,4 +1,4 @@
//! Prover
//! Prover subroutines for a SumCheck protocol.
use super::SumCheckProver;
use crate::{
@@ -15,14 +15,14 @@ use std::rc::Rc;
use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
type PolyList = VirtualPolynomial<F>;
type VirtualPolynomial = VirtualPolynomial<F>;
type ProverMessage = IOPProverMessage<F>;
/// Initialize the prover to argue for the sum of polynomial over
/// {0,1}^`num_vars`
fn prover_init(polynomial: &Self::PolyList) -> Result<Self, PolyIOPErrors> {
/// Initialize the prover state to argue for the sum of the input polynomial
/// over {0,1}^`num_vars`.
fn prover_init(polynomial: &Self::VirtualPolynomial) -> Result<Self, PolyIOPErrors> {
let start = start_timer!(|| "sum check prover init");
if polynomial.domain_info.num_variables == 0 {
if polynomial.aux_info.num_variables == 0 {
return Err(PolyIOPErrors::InvalidParameters(
"Attempt to prove a constant.".to_string(),
));
@@ -30,14 +30,14 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
end_timer!(start);
Ok(Self {
challenges: Vec::with_capacity(polynomial.domain_info.num_variables),
challenges: Vec::with_capacity(polynomial.aux_info.num_variables),
round: 0,
poly: polynomial.clone(),
})
}
/// Receive message from verifier, generate prover message, and proceed to
/// next round
/// next round.
///
/// Main algorithm used is from section 3.2 of [XZZPS19](https://eprint.iacr.org/2019/317.pdf#subsection.3.2).
fn prove_round_and_update_state(
@@ -47,8 +47,25 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
let start =
start_timer!(|| format!("sum check prove {}-th round and update state", self.round));
if self.round >= self.poly.aux_info.num_variables {
return Err(PolyIOPErrors::InvalidProver(
"Prover is not active".to_string(),
));
}
let fix_argument = start_timer!(|| "fix argument");
// Step 1:
// fix argument and evaluate f(x) over x_m = r; where r is the challenge
// for the current round, and m is the round number, indexed from 1
//
// i.e.:
// at round m <=n, for each mle g(x_1, ... x_n) within the flattened_mle
// which has already been evaluated to
//
// g(r_1, ..., r_{m-1}, x_m ... x_n)
//
// eval g over r_m, and mutate g to g(r_1, ... r_m,, x_{m+1}... x_n)
let mut flattened_ml_extensions: Vec<DenseMultilinearExtension<F>> = self
.poly
.flattened_ml_extensions
@@ -64,18 +81,16 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
}
self.challenges.push(*chal);
// fix argument
let i = self.round;
let r = self.challenges[i - 1];
let r = self.challenges[self.round - 1];
#[cfg(feature = "parallel")]
flattened_ml_extensions
.par_iter_mut()
.for_each(|multiplicand| *multiplicand = multiplicand.fix_variables(&[r]));
.for_each(|mle| *mle = mle.fix_variables(&[r]));
#[cfg(not(feature = "parallel"))]
flattened_ml_extensions
.iter_mut()
.for_each(|multiplicand| *multiplicand = multiplicand.fix_variables(&[r]));
.for_each(|mle| *mle = mle.fix_variables(&[r]));
} else if self.round > 0 {
return Err(PolyIOPErrors::InvalidProver(
"verifier message is empty".to_string(),
@@ -85,30 +100,22 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
self.round += 1;
if self.round > self.poly.domain_info.num_variables {
return Err(PolyIOPErrors::InvalidProver(
"Prover is not active".to_string(),
));
}
let products_list = self.poly.products.clone();
let i = self.round;
let nv = self.poly.domain_info.num_variables;
let degree = self.poly.domain_info.max_degree; // the degree of univariate polynomial sent by prover at this round
let mut products_sum = Vec::with_capacity(degree + 1);
products_sum.resize(degree + 1, F::zero());
let mut products_sum = Vec::with_capacity(self.poly.aux_info.max_degree + 1);
products_sum.resize(self.poly.aux_info.max_degree + 1, F::zero());
let compute_sum = start_timer!(|| "compute sum");
// generate sum
// Step 2: generate sum for the partial evaluated polynomial:
// f(r_1, ... r_m,, x_{m+1}... x_n)
#[cfg(feature = "parallel")]
products_sum.par_iter_mut().enumerate().for_each(|(t, e)| {
for b in 0..1 << (nv - i) {
for b in 0..1 << (self.poly.aux_info.num_variables - self.round) {
// evaluate P_round(t)
for (coefficient, products) in products_list.iter() {
let num_multiplicands = products.len();
let num_mles = products.len();
let mut product = *coefficient;
for &f in products.iter().take(num_multiplicands) {
for &f in products.iter().take(num_mles) {
let table = &flattened_ml_extensions[f]; // f's range is checked in init
product *= table[b << 1] * (F::one() - F::from(t as u64))
+ table[(b << 1) + 1] * F::from(t as u64);
@@ -119,26 +126,23 @@ impl<F: PrimeField> SumCheckProver<F> for IOPProverState<F> {
});
#[cfg(not(feature = "parallel"))]
for b in 0..1 << (nv - i) {
products_sum
.iter_mut()
.take(degree + 1)
.enumerate()
.for_each(|(t, e)| {
// evaluate P_round(t)
for (coefficient, products) in products_list.iter() {
let num_multiplicands = products.len();
let mut product = *coefficient;
for &f in products.iter().take(num_multiplicands) {
let table = &flattened_ml_extensions[f]; // f's range is checked in init
product *= table[b << 1] * (F::one() - F::from(t as u64))
+ table[(b << 1) + 1] * F::from(t as u64);
}
*e += product;
products_sum.iter_mut().enumerate().for_each(|(t, e)| {
for b in 0..1 << (self.poly.aux_info.num_variables - self.round) {
// evaluate P_round(t)
for (coefficient, products) in products_list.iter() {
let num_mles = products.len();
let mut product = *coefficient;
for &f in products.iter().take(num_mles) {
let table = &flattened_ml_extensions[f]; // f's range is checked in init
product *= table[b << 1] * (F::one() - F::from(t as u64))
+ table[(b << 1) + 1] * F::from(t as u64);
}
});
}
*e += product;
}
}
});
// update prover's state to the partial evaluated polynomial
self.poly.flattened_ml_extensions = flattened_ml_extensions
.iter()
.map(|x| Rc::new(x.clone()))

View File

@@ -1,11 +1,11 @@
// TODO: some of the struct is generic for Sum Checks and Zero Checks.
// If so move them to src/structs.rs
//! Verifier subroutines for a SumCheck protocol.
use super::SumCheckVerifier;
use crate::{
errors::PolyIOPErrors,
structs::{DomainInfo, IOPProverMessage, IOPVerifierState, SubClaim},
structs::{IOPProverMessage, IOPVerifierState, SubClaim},
transcript::IOPTranscript,
virtual_poly::VPAuxInfo,
};
use ark_ff::PrimeField;
use ark_std::{end_timer, start_timer};
@@ -14,14 +14,14 @@ use ark_std::{end_timer, start_timer};
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
impl<F: PrimeField> SumCheckVerifier<F> for IOPVerifierState<F> {
type DomainInfo = DomainInfo<F>;
type VPAuxInfo = VPAuxInfo<F>;
type ProverMessage = IOPProverMessage<F>;
type Challenge = F;
type Transcript = IOPTranscript<F>;
type SubClaim = SubClaim<F>;
/// initialize the verifier
fn verifier_init(index_info: &Self::DomainInfo) -> Self {
/// Initialize the verifier's state.
fn verifier_init(index_info: &Self::VPAuxInfo) -> Self {
let start = start_timer!(|| "sum check verifier init");
let res = Self {
round: 1,
@@ -35,12 +35,12 @@ impl<F: PrimeField> SumCheckVerifier<F> for IOPVerifierState<F> {
res
}
/// Run verifier at current round, given prover message
/// Run verifier for the current round, given a prover message.
///
/// Normally, this function should perform actual verification. Instead,
/// `verify_round` only samples and stores randomness and perform
/// verifications altogether in `check_and_generate_subclaim` at
/// the last step.
/// Note that `verify_round_and_update_state` only samples and stores
/// challenges; and update the verifier's state accordingly. The actual
/// verifications are deferred (in batch) to `check_and_generate_subclaim`
/// at the last step.
fn verify_round_and_update_state(
&mut self,
prover_msg: &Self::ProverMessage,
@@ -55,23 +55,24 @@ impl<F: PrimeField> SumCheckVerifier<F> for IOPVerifierState<F> {
));
}
// Now, verifier should check if the received P(0) + P(1) = expected. The check
// is moved to `check_and_generate_subclaim`, and will be done after the
// last round.
// In an interactive protocol, the verifier should
//
// 1. check if the received 'P(0) + P(1) = expected`.
// 2. set `expected` to P(r)`
//
// When we turn the protocol to a non-interactive one, it is sufficient to defer
// such checks to `check_and_generate_subclaim` after the last round.
let challenge = transcript.get_and_append_challenge(b"Internal round")?;
self.challenges.push(challenge);
self.polynomials_received
.push(prover_msg.evaluations.to_vec());
// Now, verifier should set `expected` to P(r).
// This operation is also moved to `check_and_generate_subclaim`,
// and will be done after the last round.
if self.round == self.num_vars {
// accept and close
self.finished = true;
} else {
// proceed to the next round
self.round += 1;
}
@@ -79,10 +80,12 @@ impl<F: PrimeField> SumCheckVerifier<F> for IOPVerifierState<F> {
Ok(challenge)
}
/// verify the sumcheck phase, and generate the subclaim
/// This function verifies the deferred checks in the interactive version of
/// the protocol; and generate the subclaim. Returns an error if the
/// proof failed to verify.
///
/// If the asserted sum is correct, then the multilinear polynomial
/// evaluated at `subclaim.point` is `subclaim.expected_evaluation`.
/// evaluated at `subclaim.point` will be `subclaim.expected_evaluation`.
/// Otherwise, it is highly unlikely that those two will be equal.
/// Larger field size guarantees smaller soundness error.
fn check_and_generate_subclaim(
@@ -102,6 +105,8 @@ impl<F: PrimeField> SumCheckVerifier<F> for IOPVerifierState<F> {
));
}
// the deferred check during the interactive phase:
// 2. set `expected` to P(r)`
#[cfg(feature = "parallel")]
let mut expected_vec = self
.polynomials_received
@@ -137,6 +142,7 @@ impl<F: PrimeField> SumCheckVerifier<F> for IOPVerifierState<F> {
interpolate_uni_poly::<F>(&evaluations, challenge)
})
.collect::<Result<Vec<_>, PolyIOPErrors>>()?;
// insert the asserted_sum to the first position of the expected vector
expected_vec.insert(0, *asserted_sum);
@@ -146,6 +152,8 @@ impl<F: PrimeField> SumCheckVerifier<F> for IOPVerifierState<F> {
.zip(expected_vec.iter())
.take(self.num_vars)
{
// the deferred check during the interactive phase:
// 1. check if the received 'P(0) + P(1) = expected`.
if evaluations[0] + evaluations[1] != expected {
return Err(PolyIOPErrors::InvalidProof(
"Prover message is not consistent with the claim.".to_string(),
@@ -154,8 +162,9 @@ impl<F: PrimeField> SumCheckVerifier<F> for IOPVerifierState<F> {
}
end_timer!(start);
Ok(SubClaim {
point: self.challenges.to_vec(),
// the last expected value (unchecked) will be included in the subclaim
point: self.challenges.clone(),
// the last expected value (not checked within this function) will be included in the
// subclaim
expected_evaluation: expected_vec[self.num_vars],
})
}
@@ -163,19 +172,20 @@ impl<F: PrimeField> SumCheckVerifier<F> for IOPVerifierState<F> {
/// Interpolate a uni-variate degree-`p_i.len()-1` polynomial and evaluate this
/// polynomial at `eval_at`:
///
/// \sum_{i=0}^len p_i * (\prod_{j!=i} (eval_at - j)/(i-j) )
///
/// This implementation is linear in number of inputs in terms of field
/// operations. It also has a quadratic term in primitive operations which is
/// negligible compared to field operations.
pub(crate) fn interpolate_uni_poly<F: PrimeField>(
p_i: &[F],
eval_at: F,
) -> Result<F, PolyIOPErrors> {
fn interpolate_uni_poly<F: PrimeField>(p_i: &[F], eval_at: F) -> Result<F, PolyIOPErrors> {
let start = start_timer!(|| "sum check interpolate uni poly opt");
let mut res = F::zero();
// prod = \prod_{j!=i} (eval_at - j)
// compute
// - prod = \prod (eval_at - j)
// - evals = [eval_at - j]
let mut evals = vec![];
let len = p_i.len();
let mut prod = eval_at;
@@ -188,6 +198,7 @@ pub(crate) fn interpolate_uni_poly<F: PrimeField>(
}
for i in 0..len {
// res += p_i * prod / (divisor * (eval_at - j))
let divisor = get_divisor(i, len)?;
let divisor_f = {
if divisor < 0 {