Browse Source

refactor r1csproof to non-zk sumcheck

master
Mara Mihali 2 years ago
parent
commit
1e9930ae79
2 changed files with 218 additions and 261 deletions
  1. +72
    -251
      src/r1csproof.rs
  2. +146
    -10
      src/sumcheck.rs

+ 72
- 251
src/r1csproof.rs

@ -1,12 +1,15 @@
#![allow(clippy::too_many_arguments)] #![allow(clippy::too_many_arguments)]
use crate::group::CompressedGroup; use crate::group::CompressedGroup;
use crate::sumcheck::SumcheckInstanceProof;
use super::commitments::{Commitments, MultiCommitGens}; use super::commitments::{Commitments, MultiCommitGens};
use super::dense_mlpoly::{ use super::dense_mlpoly::{
DensePolynomial, EqPolynomial, PolyCommitment, PolyCommitmentGens, PolyEvalProof, DensePolynomial, EqPolynomial, PolyCommitment, PolyCommitmentGens, PolyEvalProof,
}; };
use super::errors::ProofVerifyError; use super::errors::ProofVerifyError;
use super::group::{GroupElement, VartimeMultiscalarMul, CompressGroupElement, DecompressGroupElement};
use super::group::{
CompressGroupElement, DecompressGroupElement, GroupElement, VartimeMultiscalarMul,
};
use super::nizk::{EqualityProof, KnowledgeProof, ProductProof}; use super::nizk::{EqualityProof, KnowledgeProof, ProductProof};
use super::r1csinstance::R1CSInstance; use super::r1csinstance::R1CSInstance;
use super::random::RandomTape; use super::random::RandomTape;
@ -15,29 +18,23 @@ use super::sparse_mlpoly::{SparsePolyEntry, SparsePolynomial};
use super::sumcheck::ZKSumcheckInstanceProof; use super::sumcheck::ZKSumcheckInstanceProof;
use super::timer::Timer; use super::timer::Timer;
use super::transcript::{AppendToTranscript, ProofTranscript}; use super::transcript::{AppendToTranscript, ProofTranscript};
use core::iter;
use ark_ec::ProjectiveCurve;
use ark_ff::PrimeField; use ark_ff::PrimeField;
use merlin::Transcript;
use ark_serialize::*; use ark_serialize::*;
use ark_std::{Zero, One};
use ark_ec::{ProjectiveCurve};
use ark_std::{One, Zero};
use core::iter;
use merlin::Transcript;
#[derive(CanonicalSerialize, CanonicalDeserialize, Debug)] #[derive(CanonicalSerialize, CanonicalDeserialize, Debug)]
pub struct R1CSProof { pub struct R1CSProof {
comm_vars: PolyCommitment,
sc_proof_phase1: ZKSumcheckInstanceProof,
claims_phase2: (
CompressedGroup,
CompressedGroup,
CompressedGroup,
CompressedGroup,
),
pok_claims_phase2: (KnowledgeProof, ProductProof),
proof_eq_sc_phase1: EqualityProof,
sc_proof_phase2: ZKSumcheckInstanceProof,
comm_vars_at_ry: CompressedGroup,
proof_eval_vars_at_ry: PolyEvalProof,
proof_eq_sc_phase2: EqualityProof,
sc_proof_phase1: SumcheckInstanceProof,
claims_phase2: (Scalar, Scalar, Scalar, Scalar),
// pok_claims_phase2: (KnowledgeProof, ProductProof),
// proof_eq_sc_phase1: EqualityProof,
sc_proof_phase2: SumcheckInstanceProof,
eval_vars_at_ry: Scalar,
// proof_eval_vars_at_ry: PolyEvalProof,
// proof_eq_sc_phase2: EqualityProof,
} }
pub struct R1CSSumcheckGens { pub struct R1CSSumcheckGens {
@ -82,61 +79,43 @@ impl R1CSProof {
evals_Az: &mut DensePolynomial, evals_Az: &mut DensePolynomial,
evals_Bz: &mut DensePolynomial, evals_Bz: &mut DensePolynomial,
evals_Cz: &mut DensePolynomial, evals_Cz: &mut DensePolynomial,
gens: &R1CSSumcheckGens,
transcript: &mut Transcript, transcript: &mut Transcript,
random_tape: &mut RandomTape,
) -> (ZKSumcheckInstanceProof, Vec<Scalar>, Vec<Scalar>, Scalar) {
let comb_func = |poly_A_comp: &Scalar,
poly_B_comp: &Scalar,
poly_C_comp: &Scalar,
poly_D_comp: &Scalar|
-> Scalar { (*poly_A_comp) * ((*poly_B_comp) * poly_C_comp - poly_D_comp) };
let (sc_proof_phase_one, r, claims, blind_claim_postsc) =
ZKSumcheckInstanceProof::prove_cubic_with_additive_term(
&Scalar::zero(), // claim is zero
&Scalar::zero(), // blind for claim is also zero
num_rounds,
evals_tau,
evals_Az,
evals_Bz,
evals_Cz,
comb_func,
&gens.gens_1,
&gens.gens_4,
transcript,
random_tape,
);
) -> (SumcheckInstanceProof, Vec<Scalar>, Vec<Scalar>) {
let comb_func =
|poly_tau_comp: &Scalar,
poly_A_comp: &Scalar,
poly_B_comp: &Scalar,
poly_C_comp: &Scalar|
-> Scalar { (*poly_tau_comp) * ((*poly_A_comp) * poly_B_comp - poly_C_comp) };
let (sc_proof_phase_one, r, claims) = SumcheckInstanceProof::prove_cubic_with_additive_term(
&Scalar::zero(), // claim is zero
num_rounds,
evals_tau,
evals_Az,
evals_Bz,
evals_Cz,
comb_func,
transcript,
);
(sc_proof_phase_one, r, claims, blind_claim_postsc)
(sc_proof_phase_one, r, claims)
} }
fn prove_phase_two( fn prove_phase_two(
num_rounds: usize, num_rounds: usize,
claim: &Scalar, claim: &Scalar,
blind_claim: &Scalar,
evals_z: &mut DensePolynomial, evals_z: &mut DensePolynomial,
evals_ABC: &mut DensePolynomial, evals_ABC: &mut DensePolynomial,
gens: &R1CSSumcheckGens,
transcript: &mut Transcript, transcript: &mut Transcript,
random_tape: &mut RandomTape,
) -> (ZKSumcheckInstanceProof, Vec<Scalar>, Vec<Scalar>, Scalar) {
) -> (SumcheckInstanceProof, Vec<Scalar>, Vec<Scalar>) {
let comb_func = let comb_func =
|poly_A_comp: &Scalar, poly_B_comp: &Scalar| -> Scalar { (*poly_A_comp) * poly_B_comp }; |poly_A_comp: &Scalar, poly_B_comp: &Scalar| -> Scalar { (*poly_A_comp) * poly_B_comp };
let (sc_proof_phase_two, r, claims, blind_claim_postsc) = ZKSumcheckInstanceProof::prove_quad(
claim,
blind_claim,
num_rounds,
evals_z,
evals_ABC,
comb_func,
&gens.gens_1,
&gens.gens_3,
transcript,
random_tape,
let (sc_proof_phase_two, r, claims) = SumcheckInstanceProof::prove_quad(
claim, num_rounds, evals_z, evals_ABC, comb_func, transcript,
); );
(sc_proof_phase_two, r, claims, blind_claim_postsc)
(sc_proof_phase_two, r, claims)
} }
fn protocol_name() -> &'static [u8] { fn protocol_name() -> &'static [u8] {
@ -159,19 +138,7 @@ impl R1CSProof {
input.append_to_transcript(b"input", transcript); input.append_to_transcript(b"input", transcript);
let timer_commit = Timer::new("polycommit");
let (poly_vars, comm_vars, blinds_vars) = {
// create a multilinear polynomial using the supplied assignment for variables
let poly_vars = DensePolynomial::new(vars.clone());
// produce a commitment to the satisfying assignment
let (comm_vars, blinds_vars) = poly_vars.commit(&gens.gens_pc, Some(random_tape));
// add the commitment to the prover's transcript
comm_vars.append_to_transcript(b"poly_commitment", transcript);
(poly_vars, comm_vars, blinds_vars)
};
timer_commit.stop();
let poly_vars = DensePolynomial::new(vars.clone());
let timer_sc_proof_phase1 = Timer::new("prove_sc_phase_one"); let timer_sc_proof_phase1 = Timer::new("prove_sc_phase_one");
@ -195,15 +162,13 @@ impl R1CSProof {
let (mut poly_Az, mut poly_Bz, mut poly_Cz) = let (mut poly_Az, mut poly_Bz, mut poly_Cz) =
inst.multiply_vec(inst.get_num_cons(), z.len(), &z); inst.multiply_vec(inst.get_num_cons(), z.len(), &z);
let (sc_proof_phase1, rx, _claims_phase1, blind_claim_postsc1) = R1CSProof::prove_phase_one(
let (sc_proof_phase1, rx, _claims_phase1) = R1CSProof::prove_phase_one(
num_rounds_x, num_rounds_x,
&mut poly_tau, &mut poly_tau,
&mut poly_Az, &mut poly_Az,
&mut poly_Bz, &mut poly_Bz,
&mut poly_Cz, &mut poly_Cz,
&gens.gens_sc,
transcript, transcript,
random_tape,
); );
assert_eq!(poly_tau.len(), 1); assert_eq!(poly_tau.len(), 1);
assert_eq!(poly_Az.len(), 1); assert_eq!(poly_Az.len(), 1);
@ -213,56 +178,11 @@ impl R1CSProof {
let (tau_claim, Az_claim, Bz_claim, Cz_claim) = let (tau_claim, Az_claim, Bz_claim, Cz_claim) =
(&poly_tau[0], &poly_Az[0], &poly_Bz[0], &poly_Cz[0]); (&poly_tau[0], &poly_Az[0], &poly_Bz[0], &poly_Cz[0]);
let (Az_blind, Bz_blind, Cz_blind, prod_Az_Bz_blind) = (
random_tape.random_scalar(b"Az_blind"),
random_tape.random_scalar(b"Bz_blind"),
random_tape.random_scalar(b"Cz_blind"),
random_tape.random_scalar(b"prod_Az_Bz_blind"),
);
let (pok_Cz_claim, comm_Cz_claim) = {
KnowledgeProof::prove(
&gens.gens_sc.gens_1,
transcript,
random_tape,
Cz_claim,
&Cz_blind,
)
};
let (proof_prod, comm_Az_claim, comm_Bz_claim, comm_prod_Az_Bz_claims) = {
let prod = (*Az_claim) * Bz_claim;
ProductProof::prove(
&gens.gens_sc.gens_1,
transcript,
random_tape,
Az_claim,
&Az_blind,
Bz_claim,
&Bz_blind,
&prod,
&prod_Az_Bz_blind,
)
};
comm_Az_claim.append_to_transcript(b"comm_Az_claim", transcript);
comm_Bz_claim.append_to_transcript(b"comm_Bz_claim", transcript);
comm_Cz_claim.append_to_transcript(b"comm_Cz_claim", transcript);
comm_prod_Az_Bz_claims.append_to_transcript(b"comm_prod_Az_Bz_claims", transcript);
let prod_Az_Bz_claims = (*Az_claim) * Bz_claim;
// prove the final step of sum-check #1 // prove the final step of sum-check #1
let taus_bound_rx = tau_claim; let taus_bound_rx = tau_claim;
let blind_expected_claim_postsc1 = (prod_Az_Bz_blind - Cz_blind) * taus_bound_rx;
let claim_post_phase1 = ((*Az_claim) * Bz_claim - Cz_claim) * taus_bound_rx; let claim_post_phase1 = ((*Az_claim) * Bz_claim - Cz_claim) * taus_bound_rx;
let (proof_eq_sc_phase1, _C1, _C2) = EqualityProof::prove(
&gens.gens_sc.gens_1,
transcript,
random_tape,
&claim_post_phase1,
&blind_expected_claim_postsc1,
&claim_post_phase1,
&blind_claim_postsc1,
);
let timer_sc_proof_phase2 = Timer::new("prove_sc_phase_two"); let timer_sc_proof_phase2 = Timer::new("prove_sc_phase_two");
// combine the three claims into a single claim // combine the three claims into a single claim
@ -270,7 +190,6 @@ impl R1CSProof {
let r_B = transcript.challenge_scalar(b"challenege_Bz"); let r_B = transcript.challenge_scalar(b"challenege_Bz");
let r_C = transcript.challenge_scalar(b"challenege_Cz"); let r_C = transcript.challenge_scalar(b"challenege_Cz");
let claim_phase2 = r_A * Az_claim + r_B * Bz_claim + r_C * Cz_claim; let claim_phase2 = r_A * Az_claim + r_B * Bz_claim + r_C * Cz_claim;
let blind_claim_phase2 = r_A * Az_blind + r_B * Bz_blind + r_C * Cz_blind;
let evals_ABC = { let evals_ABC = {
// compute the initial evaluation table for R(\tau, x) // compute the initial evaluation table for R(\tau, x)
@ -286,65 +205,27 @@ impl R1CSProof {
}; };
// another instance of the sum-check protocol // another instance of the sum-check protocol
let (sc_proof_phase2, ry, claims_phase2, blind_claim_postsc2) = R1CSProof::prove_phase_two(
let (sc_proof_phase2, ry, claims_phase2) = R1CSProof::prove_phase_two(
num_rounds_y, num_rounds_y,
&claim_phase2, &claim_phase2,
&blind_claim_phase2,
&mut DensePolynomial::new(z), &mut DensePolynomial::new(z),
&mut DensePolynomial::new(evals_ABC), &mut DensePolynomial::new(evals_ABC),
&gens.gens_sc,
transcript, transcript,
random_tape,
); );
timer_sc_proof_phase2.stop(); timer_sc_proof_phase2.stop();
let timer_polyeval = Timer::new("polyeval"); let timer_polyeval = Timer::new("polyeval");
let eval_vars_at_ry = poly_vars.evaluate(&ry[1..].to_vec()); let eval_vars_at_ry = poly_vars.evaluate(&ry[1..].to_vec());
let blind_eval = random_tape.random_scalar(b"blind_eval");
let (proof_eval_vars_at_ry, comm_vars_at_ry) = PolyEvalProof::prove(
&poly_vars,
Some(&blinds_vars),
&ry[1..].to_vec(),
&eval_vars_at_ry,
Some(&blind_eval),
&gens.gens_pc,
transcript,
random_tape,
);
timer_polyeval.stop(); timer_polyeval.stop();
// prove the final step of sum-check #2
let blind_eval_Z_at_ry = (Scalar::one() - ry[0]) * blind_eval;
let blind_expected_claim_postsc2 = claims_phase2[1] * blind_eval_Z_at_ry;
let claim_post_phase2 = claims_phase2[0] * claims_phase2[1];
let (proof_eq_sc_phase2, _C1, _C2) = EqualityProof::prove(
&gens.gens_pc.gens.gens_1,
transcript,
random_tape,
&claim_post_phase2,
&blind_expected_claim_postsc2,
&claim_post_phase2,
&blind_claim_postsc2,
);
timer_prove.stop(); timer_prove.stop();
( (
R1CSProof { R1CSProof {
comm_vars,
sc_proof_phase1, sc_proof_phase1,
claims_phase2: (
comm_Az_claim,
comm_Bz_claim,
comm_Cz_claim,
comm_prod_Az_Bz_claims,
),
pok_claims_phase2: (pok_Cz_claim, proof_prod),
proof_eq_sc_phase1,
claims_phase2: (*Az_claim, *Bz_claim, *Cz_claim, prod_Az_Bz_claims),
sc_proof_phase2, sc_proof_phase2,
comm_vars_at_ry,
proof_eval_vars_at_ry,
proof_eq_sc_phase2,
eval_vars_at_ry,
}, },
rx, rx,
ry, ry,
@ -365,10 +246,6 @@ impl R1CSProof {
input.append_to_transcript(b"input", transcript); input.append_to_transcript(b"input", transcript);
let n = num_vars; let n = num_vars;
// add the commitment to the verifier's transcript
self
.comm_vars
.append_to_transcript(b"poly_commitment", transcript);
let (num_rounds_x, num_rounds_y) = (num_cons.log2() as usize, (2 * num_vars).log2() as usize); let (num_rounds_x, num_rounds_y) = (num_cons.log2() as usize, (2 * num_vars).log2() as usize);
@ -376,85 +253,35 @@ impl R1CSProof {
let tau = transcript.challenge_vector(b"challenge_tau", num_rounds_x); let tau = transcript.challenge_vector(b"challenge_tau", num_rounds_x);
// verify the first sum-check instance // verify the first sum-check instance
let claim_phase1 = Scalar::zero()
.commit(&Scalar::zero(), &gens.gens_sc.gens_1)
.compress();
let (comm_claim_post_phase1, rx) = self.sc_proof_phase1.verify(
&claim_phase1,
num_rounds_x,
3,
&gens.gens_sc.gens_1,
&gens.gens_sc.gens_4,
transcript,
)?;
// perform the intermediate sum-check test with claimed Az, Bz, and Cz
let (comm_Az_claim, comm_Bz_claim, comm_Cz_claim, comm_prod_Az_Bz_claims) = &self.claims_phase2;
let (pok_Cz_claim, proof_prod) = &self.pok_claims_phase2;
let claim_phase1 = Scalar::zero();
let (claim_post_phase1, rx) =
self
.sc_proof_phase1
.verify(claim_phase1, num_rounds_x, 3, transcript)?;
pok_Cz_claim.verify(&gens.gens_sc.gens_1, transcript, comm_Cz_claim)?;
proof_prod.verify(
&gens.gens_sc.gens_1,
transcript,
comm_Az_claim,
comm_Bz_claim,
comm_prod_Az_Bz_claims,
)?;
comm_Az_claim.append_to_transcript(b"comm_Az_claim", transcript);
comm_Bz_claim.append_to_transcript(b"comm_Bz_claim", transcript);
comm_Cz_claim.append_to_transcript(b"comm_Cz_claim", transcript);
comm_prod_Az_Bz_claims.append_to_transcript(b"comm_prod_Az_Bz_claims", transcript);
// perform the intermediate sum-check test with claimed Az, Bz, and Cz
let (Az_claim, Bz_claim, Cz_claim, prod_Az_Bz_claims) = &self.claims_phase2;
let taus_bound_rx: Scalar = (0..rx.len()) let taus_bound_rx: Scalar = (0..rx.len())
.map(|i| rx[i] * tau[i] + (Scalar::one() - rx[i]) * (Scalar::one() - tau[i])) .map(|i| rx[i] * tau[i] + (Scalar::one() - rx[i]) * (Scalar::one() - tau[i]))
.product(); .product();
let expected_claim_post_phase1 = (GroupElement::decompress(comm_prod_Az_Bz_claims).unwrap() - GroupElement::decompress(comm_Cz_claim).unwrap()).mul(taus_bound_rx.into_repr())
.compress();
// verify proof that expected_claim_post_phase1 == claim_post_phase1
self.proof_eq_sc_phase1.verify(
&gens.gens_sc.gens_1,
transcript,
&expected_claim_post_phase1,
&comm_claim_post_phase1,
)?;
let expected_claim_post_phase1 = (*prod_Az_Bz_claims - *Cz_claim) * (taus_bound_rx);
assert_eq!(claim_post_phase1, expected_claim_post_phase1);
// derive three public challenges and then derive a joint claim // derive three public challenges and then derive a joint claim
let r_A = transcript.challenge_scalar(b"challenege_Az"); let r_A = transcript.challenge_scalar(b"challenege_Az");
let r_B = transcript.challenge_scalar(b"challenege_Bz"); let r_B = transcript.challenge_scalar(b"challenege_Bz");
let r_C = transcript.challenge_scalar(b"challenege_Cz"); let r_C = transcript.challenge_scalar(b"challenege_Cz");
// r_A * comm_Az_claim + r_B * comm_Bz_claim + r_C * comm_Cz_claim;
let comm_claim_phase2 = GroupElement::vartime_multiscalar_mul(
iter::once(&r_A)
.chain(iter::once(&r_B))
.chain(iter::once(&r_C)).map(|s| (*s)).collect::<Vec<Scalar>>().as_slice(),
iter::once(&comm_Az_claim)
.chain(iter::once(&comm_Bz_claim))
.chain(iter::once(&comm_Cz_claim))
.map(|pt| GroupElement::decompress(pt).unwrap())
.collect::<Vec<GroupElement>>().as_slice(),
)
.compress();
let claim_phase2 = r_A * Az_claim + r_B * Bz_claim + r_C * Cz_claim;
// verify the joint claim with a sum-check protocol // verify the joint claim with a sum-check protocol
let (comm_claim_post_phase2, ry) = self.sc_proof_phase2.verify(
&comm_claim_phase2,
num_rounds_y,
2,
&gens.gens_sc.gens_1,
&gens.gens_sc.gens_3,
transcript,
)?;
// verify Z(ry) proof against the initial commitment
self.proof_eval_vars_at_ry.verify(
&gens.gens_pc,
transcript,
&ry[1..].to_vec(),
&self.comm_vars_at_ry,
&self.comm_vars,
)?;
let (claim_post_phase2, ry) =
self
.sc_proof_phase2
.verify(claim_phase2, num_rounds_y, 2, transcript)?;
let poly_input_eval = { let poly_input_eval = {
// constant term // constant term
@ -469,26 +296,14 @@ impl R1CSProof {
.evaluate(&ry[1..].to_vec()) .evaluate(&ry[1..].to_vec())
}; };
// compute commitment to eval_Z_at_ry = (Scalar::one() - ry[0]) * self.eval_vars_at_ry + ry[0] * poly_input_eval
let comm_eval_Z_at_ry = GroupElement::vartime_multiscalar_mul(
iter::once(Scalar::one() - ry[0]).chain(iter::once(ry[0])).collect::<Vec<Scalar>>().as_slice(),
iter::once(GroupElement::decompress(&self.comm_vars_at_ry).unwrap()).chain(iter::once(
poly_input_eval.commit(&Scalar::zero(), &gens.gens_pc.gens.gens_1),
)).collect::<Vec<GroupElement>>().as_slice(),
);
let eval_Z_at_ry = (Scalar::one() - ry[0]) * self.eval_vars_at_ry + ry[0] * poly_input_eval;
// perform the final check in the second sum-check protocol // perform the final check in the second sum-check protocol
let (eval_A_r, eval_B_r, eval_C_r) = evals; let (eval_A_r, eval_B_r, eval_C_r) = evals;
let scalar = r_A * eval_A_r + r_B * eval_B_r + r_C * eval_C_r; let scalar = r_A * eval_A_r + r_B * eval_B_r + r_C * eval_C_r;
let expected_claim_post_phase2 =
comm_eval_Z_at_ry.mul(scalar.into_repr()).compress();
// verify proof that expected_claim_post_phase1 == claim_post_phase1
self.proof_eq_sc_phase2.verify(
&gens.gens_sc.gens_1,
transcript,
&expected_claim_post_phase2,
&comm_claim_post_phase2,
)?;
let expected_claim_post_phase2 = eval_Z_at_ry * scalar;
assert_eq!(expected_claim_post_phase2, claim_post_phase2);
Ok((rx, ry)) Ok((rx, ry))
} }
@ -497,7 +312,8 @@ impl R1CSProof {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use ark_std::{UniformRand};
use ark_std::UniformRand;
use test::Bencher;
fn produce_tiny_r1cs() -> (R1CSInstance, Vec<Scalar>, Vec<Scalar>) { fn produce_tiny_r1cs() -> (R1CSInstance, Vec<Scalar>, Vec<Scalar>) {
// three constraints over five variables Z1, Z2, Z3, Z4, and Z5 // three constraints over five variables Z1, Z2, Z3, Z4, and Z5
@ -533,7 +349,7 @@ use ark_std::{UniformRand};
let inst = R1CSInstance::new(num_cons, num_vars, num_inputs, &A, &B, &C); let inst = R1CSInstance::new(num_cons, num_vars, num_inputs, &A, &B, &C);
// compute a satisfying assignment // compute a satisfying assignment
let mut rng = ark_std::rand::thread_rng();
let mut rng = ark_std::rand::thread_rng();
let i0 = Scalar::rand(&mut rng); let i0 = Scalar::rand(&mut rng);
let i1 = Scalar::rand(&mut rng); let i1 = Scalar::rand(&mut rng);
let z1 = Scalar::rand(&mut rng); let z1 = Scalar::rand(&mut rng);
@ -604,4 +420,9 @@ use ark_std::{UniformRand};
) )
.is_ok()); .is_ok());
} }
#[bench]
fn bench_r1cs_proof(b: &mut Bencher) {
b.iter(|| check_r1cs_proof());
}
} }

+ 146
- 10
src/sumcheck.rs

@ -3,17 +3,20 @@
use super::commitments::{Commitments, MultiCommitGens}; use super::commitments::{Commitments, MultiCommitGens};
use super::dense_mlpoly::DensePolynomial; use super::dense_mlpoly::DensePolynomial;
use super::errors::ProofVerifyError; use super::errors::ProofVerifyError;
use super::group::{CompressedGroup, GroupElement, VartimeMultiscalarMul, CompressGroupElement, DecompressGroupElement};
use super::group::{
CompressGroupElement, CompressedGroup, DecompressGroupElement, GroupElement,
VartimeMultiscalarMul,
};
use super::nizk::DotProductProof; use super::nizk::DotProductProof;
use super::random::RandomTape; use super::random::RandomTape;
use super::scalar::Scalar; use super::scalar::Scalar;
use super::transcript::{AppendToTranscript, ProofTranscript}; use super::transcript::{AppendToTranscript, ProofTranscript};
use super::unipoly::{CompressedUniPoly, UniPoly}; use super::unipoly::{CompressedUniPoly, UniPoly};
use ark_ff::{One, Zero};
use ark_serialize::*;
use core::iter; use core::iter;
use itertools::izip; use itertools::izip;
use merlin::Transcript; use merlin::Transcript;
use ark_serialize::*;
use ark_ff::{One,Zero};
#[derive(CanonicalSerialize, CanonicalDeserialize, Debug)] #[derive(CanonicalSerialize, CanonicalDeserialize, Debug)]
pub struct SumcheckInstanceProof { pub struct SumcheckInstanceProof {
@ -130,7 +133,8 @@ impl ZKSumcheckInstanceProof {
iter::once(&comm_claim_per_round) iter::once(&comm_claim_per_round)
.chain(iter::once(&comm_eval)) .chain(iter::once(&comm_eval))
.map(|pt| GroupElement::decompress(pt).unwrap()) .map(|pt| GroupElement::decompress(pt).unwrap())
.collect::<Vec<GroupElement>>().as_slice(),
.collect::<Vec<GroupElement>>()
.as_slice(),
) )
.compress(); .compress();
@ -181,6 +185,83 @@ impl ZKSumcheckInstanceProof {
} }
impl SumcheckInstanceProof { impl SumcheckInstanceProof {
pub fn prove_cubic_with_additive_term<F>(
claim: &Scalar,
num_rounds: usize,
poly_tau: &mut DensePolynomial,
poly_A: &mut DensePolynomial,
poly_B: &mut DensePolynomial,
poly_C: &mut DensePolynomial,
comb_func: F,
transcript: &mut Transcript,
) -> (Self, Vec<Scalar>, Vec<Scalar>)
where
F: Fn(&Scalar, &Scalar, &Scalar, &Scalar) -> Scalar,
{
let mut e = *claim;
let mut r: Vec<Scalar> = Vec::new();
let mut cubic_polys: Vec<CompressedUniPoly> = Vec::new();
for j in 0..num_rounds {
let mut eval_point_0 = Scalar::zero();
let mut eval_point_2 = Scalar::zero();
let mut eval_point_3 = Scalar::zero();
let len = poly_tau.len() / 2;
for i in 0..len {
// eval 0: bound_func is A(low)
eval_point_0 += comb_func(&poly_tau[i], &poly_A[i], &poly_B[i], &poly_C[i]);
// eval 2: bound_func is -A(low) + 2*A(high)
let poly_tau_bound_point = poly_tau[len + i] + poly_tau[len + i] - poly_tau[i];
let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i];
let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i];
let poly_C_bound_point = poly_C[len + i] + poly_C[len + i] - poly_C[i];
eval_point_2 += comb_func(
&poly_tau_bound_point,
&poly_A_bound_point,
&poly_B_bound_point,
&poly_C_bound_point,
);
// eval 3: bound_func is -2A(low) + 3A(high); computed incrementally with bound_func applied to eval(2)
let poly_tau_bound_point = poly_tau_bound_point + poly_tau[len + i] - poly_tau[i];
let poly_A_bound_point = poly_A_bound_point + poly_A[len + i] - poly_A[i];
let poly_B_bound_point = poly_B_bound_point + poly_B[len + i] - poly_B[i];
let poly_C_bound_point = poly_C_bound_point + poly_C[len + i] - poly_C[i];
eval_point_3 += comb_func(
&poly_tau_bound_point,
&poly_A_bound_point,
&poly_B_bound_point,
&poly_C_bound_point,
);
}
let evals = vec![eval_point_0, e - eval_point_0, eval_point_2, eval_point_3];
let poly = UniPoly::from_evals(&evals);
// append the prover's message to the transcript
poly.append_to_transcript(b"poly", transcript);
//derive the verifier's challenge for the next round
let r_j = transcript.challenge_scalar(b"challenge_nextround");
r.push(r_j);
// bound all tables to the verifier's challenege
poly_tau.bound_poly_var_top(&r_j);
poly_A.bound_poly_var_top(&r_j);
poly_B.bound_poly_var_top(&r_j);
poly_C.bound_poly_var_top(&r_j);
e = poly.evaluate(&r_j);
cubic_polys.push(poly.compress());
}
(
SumcheckInstanceProof::new(cubic_polys),
r,
vec![poly_tau[0], poly_A[0], poly_B[0], poly_C[0]],
)
}
pub fn prove_cubic<F>( pub fn prove_cubic<F>(
claim: &Scalar, claim: &Scalar,
num_rounds: usize, num_rounds: usize,
@ -423,6 +504,60 @@ impl SumcheckInstanceProof {
claims_dotp, claims_dotp,
) )
} }
pub fn prove_quad<F>(
claim: &Scalar,
num_rounds: usize,
poly_A: &mut DensePolynomial,
poly_B: &mut DensePolynomial,
comb_func: F,
transcript: &mut Transcript,
) -> (Self, Vec<Scalar>, Vec<Scalar>)
where
F: Fn(&Scalar, &Scalar) -> Scalar,
{
let mut e = *claim;
let mut r: Vec<Scalar> = Vec::new();
let mut quad_polys: Vec<CompressedUniPoly> = Vec::new();
for j in 0..num_rounds {
let mut eval_point_0 = Scalar::zero();
let mut eval_point_2 = Scalar::zero();
let len = poly_A.len() / 2;
for i in 0..len {
// eval 0: bound_func is A(low)
eval_point_0 += comb_func(&poly_A[i], &poly_B[i]);
// eval 2: bound_func is -A(low) + 2*A(high)
let poly_A_bound_point = poly_A[len + i] + poly_A[len + i] - poly_A[i];
let poly_B_bound_point = poly_B[len + i] + poly_B[len + i] - poly_B[i];
eval_point_2 += comb_func(&poly_A_bound_point, &poly_B_bound_point);
}
let evals = vec![eval_point_0, e - eval_point_0, eval_point_2];
let poly = UniPoly::from_evals(&evals);
// append the prover's message to the transcript
poly.append_to_transcript(b"poly", transcript);
//derive the verifier's challenge for the next round
let r_j = transcript.challenge_scalar(b"challenge_nextround");
r.push(r_j);
// bound all tables to the verifier's challenege
poly_A.bound_poly_var_top(&r_j);
poly_B.bound_poly_var_top(&r_j);
e = poly.evaluate(&r_j);
quad_polys.push(poly.compress());
}
(
SumcheckInstanceProof::new(quad_polys),
r,
vec![poly_A[0], poly_B[0]],
)
}
} }
impl ZKSumcheckInstanceProof { impl ZKSumcheckInstanceProof {
@ -514,7 +649,8 @@ impl ZKSumcheckInstanceProof {
iter::once(&comm_claim_per_round) iter::once(&comm_claim_per_round)
.chain(iter::once(&comm_eval)) .chain(iter::once(&comm_eval))
.map(|pt| GroupElement::decompress(pt).unwrap()) .map(|pt| GroupElement::decompress(pt).unwrap())
.collect::<Vec<GroupElement>>().as_slice(),
.collect::<Vec<GroupElement>>()
.as_slice(),
) )
.compress(); .compress();
@ -693,7 +829,6 @@ impl ZKSumcheckInstanceProof {
// add two claims to transcript // add two claims to transcript
comm_claim_per_round.append_to_transcript(b"comm_claim_per_round", transcript); comm_claim_per_round.append_to_transcript(b"comm_claim_per_round", transcript);
comm_eval.append_to_transcript(b"comm_eval", transcript); comm_eval.append_to_transcript(b"comm_eval", transcript);
// produce two weights // produce two weights
let w = transcript.challenge_vector(b"combine_two_claims_to_one", 2); let w = transcript.challenge_vector(b"combine_two_claims_to_one", 2);
@ -705,9 +840,11 @@ impl ZKSumcheckInstanceProof {
w.as_slice(), w.as_slice(),
iter::once(&comm_claim_per_round) iter::once(&comm_claim_per_round)
.chain(iter::once(&comm_eval)) .chain(iter::once(&comm_eval))
.map(|pt|GroupElement::decompress(&pt).unwrap())
.collect::<Vec<GroupElement>>().as_slice(),
).compress();
.map(|pt| GroupElement::decompress(&pt).unwrap())
.collect::<Vec<GroupElement>>()
.as_slice(),
)
.compress();
let blind = { let blind = {
let blind_sc = if j == 0 { let blind_sc = if j == 0 {
@ -721,7 +858,6 @@ impl ZKSumcheckInstanceProof {
w[0] * blind_sc + w[1] * blind_eval w[0] * blind_sc + w[1] * blind_eval
}; };
let res = target.commit(&blind, gens_1); let res = target.commit(&blind, gens_1);
assert_eq!(res.compress(), comm_target); assert_eq!(res.compress(), comm_target);

Loading…
Cancel
Save