diff --git a/shockwave_plus/src/lib.rs b/shockwave_plus/src/lib.rs index 2a031a5..fdb8804 100644 --- a/shockwave_plus/src/lib.rs +++ b/shockwave_plus/src/lib.rs @@ -62,12 +62,12 @@ impl ShockwavePlus { pub fn prove( &self, - witness: &[F], + r1cs_witness: &[F], transcript: &mut Transcript, ) -> (PartialSpartanProof, Vec) { // Compute the multilinear extension of the witness - assert!(witness.len().is_power_of_two()); - let witness_poly = SparseMLPoly::from_dense(witness.to_vec()); + assert!(r1cs_witness.len().is_power_of_two()); + let witness_poly = SparseMLPoly::from_dense(r1cs_witness.to_vec()); // Commit the witness polynomial let comm_witness_timer = start_timer!(|| "Commit witness"); @@ -75,10 +75,11 @@ impl ShockwavePlus { let witness_comm = committed_witness.committed_tree.root; end_timer!(comm_witness_timer); + // Add the witness commitment to the transcript transcript.append_bytes(&witness_comm); // ############################ - // Phase 1: The sum-checks + // Phase 1 // ################### let m = (self.r1cs.num_vars as f64).log2() as usize; @@ -86,25 +87,15 @@ impl ShockwavePlus { let mut tau_rev = tau.clone(); tau_rev.reverse(); - // First - // Compute the multilinear extension of the R1CS matrices. - // Prove that he Q_poly is a zero-polynomial - - // Q_poly is a zero-polynomial iff F_io evaluates to zero - // over the m-dimensional boolean hypercube.. - - // We prove using the sum-check protocol. - - // G_poly = A_poly * B_poly - C_poly - let num_rows = self.r1cs.num_cons; - let Az_poly = self.r1cs.A.mul_vector(num_rows, witness); - let Bz_poly = self.r1cs.B.mul_vector(num_rows, witness); - let Cz_poly = self.r1cs.C.mul_vector(num_rows, witness); + let Az_poly = self.r1cs.A.mul_vector(num_rows, r1cs_witness); + let Bz_poly = self.r1cs.B.mul_vector(num_rows, r1cs_witness); + let Cz_poly = self.r1cs.C.mul_vector(num_rows, r1cs_witness); - // Prove that the polynomial Q(t) - // \sum_{x \in {0, 1}^m} (Az_poly(x) * Bz_poly(x) - Cz_poly(x)) eq(tau, x) + // Prove that the + // Q(t) = \sum_{x \in {0, 1}^m} (Az_poly(x) * Bz_poly(x) - Cz_poly(x)) eq(t, x) // is a zero-polynomial using the sum-check protocol. + // We evaluate Q(t) at $\tau$ and check that it is zero. let rx = transcript.challenge_vec(m); let mut rx_rev = rx.clone(); @@ -119,7 +110,7 @@ impl ShockwavePlus { tau_rev.clone(), rx.clone(), ); - let (sc_proof_1, (v_A, v_B, v_C)) = sc_phase_1.prove(transcript); + let (sc_proof_1, (v_A, v_B, v_C)) = sc_phase_1.prove(&self.pcs_witness, transcript); end_timer!(sc_phase_1_timer); transcript.append_fe(&v_A); @@ -137,13 +128,13 @@ impl ShockwavePlus { self.r1cs.A.clone(), self.r1cs.B.clone(), self.r1cs.C.clone(), - witness.to_vec(), + r1cs_witness.to_vec(), rx.clone(), r.as_slice().try_into().unwrap(), ry.clone(), ); - let sc_proof_2 = sc_phase_2.prove(transcript); + let sc_proof_2 = sc_phase_2.prove(&self.pcs_witness, transcript); end_timer!(sc_phase_2_timer); let mut ry_rev = ry.clone(); @@ -190,10 +181,17 @@ impl ShockwavePlus { rx_rev.reverse(); transcript.append_fe(&partial_proof.sc_proof_1.blinder_poly_sum); + transcript.append_bytes(&partial_proof.sc_proof_1.blinder_poly_eval_proof.u_hat_comm); + let rho = transcript.challenge_fe(); let ex = SumCheckPhase1::verify_round_polys(&partial_proof.sc_proof_1, &rx, rho); + self.pcs_witness.verify( + &partial_proof.sc_proof_1.blinder_poly_eval_proof, + transcript, + ); + // The final eval should equal let v_A = partial_proof.v_A; let v_B = partial_proof.v_B; @@ -201,7 +199,7 @@ impl ShockwavePlus { let T_1_eq = EqPoly::new(tau); let T_1 = (v_A * v_B - v_C) * T_1_eq.eval(&rx_rev) - + rho * partial_proof.sc_proof_1.blinder_poly_eval_claim; + + rho * partial_proof.sc_proof_1.blinder_poly_eval_proof.y; assert_eq!(T_1, ex); transcript.append_fe(&v_A); @@ -216,6 +214,8 @@ impl ShockwavePlus { let ry = transcript.challenge_vec(m); transcript.append_fe(&partial_proof.sc_proof_2.blinder_poly_sum); + transcript.append_bytes(&partial_proof.sc_proof_2.blinder_poly_eval_proof.u_hat_comm); + let rho_2 = transcript.challenge_fe(); let T_2 = @@ -223,6 +223,11 @@ impl ShockwavePlus { let final_poly_eval = SumCheckPhase2::verify_round_polys(T_2, &partial_proof.sc_proof_2, &ry); + self.pcs_witness.verify( + &partial_proof.sc_proof_2.blinder_poly_eval_proof, + transcript, + ); + let mut ry_rev = ry.clone(); ry_rev.reverse(); @@ -234,14 +239,11 @@ impl ShockwavePlus { let B_eval = B_mle.eval(&rx_ry); let C_eval = C_mle.eval(&rx_ry); - self.pcs_witness.verify( - &partial_proof.z_eval_proof, - &partial_proof.z_comm, - transcript, - ); + self.pcs_witness + .verify(&partial_proof.z_eval_proof, transcript); let T_opened = (r_A * A_eval + r_B * B_eval + r_C * C_eval) * z_eval - + rho_2 * partial_proof.sc_proof_2.blinder_poly_eval_claim; + + rho_2 * partial_proof.sc_proof_2.blinder_poly_eval_proof.y; assert_eq!(T_opened, final_poly_eval); } } diff --git a/shockwave_plus/src/polynomial/blinder_poly.rs b/shockwave_plus/src/polynomial/blinder_poly.rs deleted file mode 100644 index fcd7560..0000000 --- a/shockwave_plus/src/polynomial/blinder_poly.rs +++ /dev/null @@ -1,34 +0,0 @@ -use crate::FieldExt; - -pub struct BlinderPoly { - inner_poly_coeffs: Vec>, -} - -impl BlinderPoly { - pub fn sample_random(num_vars: usize, degree: usize) -> Self { - let mut rng = rand::thread_rng(); - let inner_poly_coeffs = (0..num_vars) - .map(|_| (0..(degree + 1)).map(|_| F::random(&mut rng)).collect()) - .collect(); - - Self { inner_poly_coeffs } - } - - pub fn eval(&self, x: &[F]) -> F { - let mut res = F::ZERO; - - for (coeffs, x_i) in self.inner_poly_coeffs.iter().zip(x.iter()) { - let mut tmp = F::ZERO; - let mut x_i_pow = F::ONE; - - for coeff in coeffs.iter() { - tmp += *coeff * x_i_pow; - x_i_pow *= x_i; - } - - res += tmp; - } - - res - } -} diff --git a/shockwave_plus/src/polynomial/mod.rs b/shockwave_plus/src/polynomial/mod.rs index afb6a6d..423ae39 100644 --- a/shockwave_plus/src/polynomial/mod.rs +++ b/shockwave_plus/src/polynomial/mod.rs @@ -1,2 +1 @@ -pub mod blinder_poly; pub mod ml_poly; diff --git a/shockwave_plus/src/r1cs/r1cs.rs b/shockwave_plus/src/r1cs/r1cs.rs index d25e6b2..5afbc25 100644 --- a/shockwave_plus/src/r1cs/r1cs.rs +++ b/shockwave_plus/src/r1cs/r1cs.rs @@ -1,5 +1,4 @@ use crate::FieldExt; -use halo2curves::ff::Field; use tensor_pcs::SparseMLPoly; #[derive(Clone)] diff --git a/shockwave_plus/src/sumcheck/sc_phase_1.rs b/shockwave_plus/src/sumcheck/sc_phase_1.rs index 0a4e627..d62e47a 100644 --- a/shockwave_plus/src/sumcheck/sc_phase_1.rs +++ b/shockwave_plus/src/sumcheck/sc_phase_1.rs @@ -1,15 +1,14 @@ -use crate::polynomial::ml_poly::MlPoly; use crate::sumcheck::unipoly::UniPoly; use serde::{Deserialize, Serialize}; -use tensor_pcs::{EqPoly, Transcript}; +use tensor_pcs::{EqPoly, SparseMLPoly, TensorMLOpening, TensorMultilinearPCS, Transcript}; use crate::FieldExt; #[derive(Serialize, Deserialize)] pub struct SCPhase1Proof { pub blinder_poly_sum: F, - pub blinder_poly_eval_claim: F, pub round_polys: Vec>, + pub blinder_poly_eval_proof: TensorMLOpening, } pub struct SumCheckPhase1 { @@ -38,19 +37,30 @@ impl SumCheckPhase1 { } } - pub fn prove(&self, transcript: &mut Transcript) -> (SCPhase1Proof, (F, F, F)) { + pub fn prove( + &self, + pcs: &TensorMultilinearPCS, + transcript: &mut Transcript, + ) -> (SCPhase1Proof, (F, F, F)) { let num_vars = (self.Az_evals.len() as f64).log2() as usize; let mut round_polys = Vec::>::with_capacity(num_vars - 1); + // We implement the zero-knowledge sumcheck protocol + // described in Section 4.1 https://eprint.iacr.org/2019/317.pdf + let mut rng = rand::thread_rng(); - // Sample a blinding polynomial g(x_1, ..., x_m) of degree 3 + // Sample a blinding polynomial g(x_1, ..., x_m) let random_evals = (0..2usize.pow(num_vars as u32)) .map(|_| F::random(&mut rng)) .collect::>(); let blinder_poly_sum = random_evals.iter().fold(F::ZERO, |acc, x| acc + x); - let blinder_poly = MlPoly::new(random_evals); + let blinder_poly = SparseMLPoly::from_dense(random_evals); + + let blinder_poly_comm = pcs.commit(&blinder_poly); transcript.append_fe(&blinder_poly_sum); + transcript.append_bytes(&blinder_poly_comm.committed_tree.root); + let rho = transcript.challenge_fe(); // Compute the sum of g(x_1, ... x_m) over the boolean hypercube @@ -60,7 +70,11 @@ impl SumCheckPhase1 { let mut A_table = self.Az_evals.clone(); let mut B_table = self.Bz_evals.clone(); let mut C_table = self.Cz_evals.clone(); - let mut blinder_table = blinder_poly.evals.clone(); + let mut blinder_table = blinder_poly + .evals + .iter() + .map(|(_, x)| *x) + .collect::>(); let mut eq_table = self.bound_eq_poly.evals(); let zero = F::ZERO; @@ -95,6 +109,7 @@ impl SumCheckPhase1 { blinder_table[b] + (blinder_table[b + high_index] - blinder_table[b]) * r_i; } + // TODO: Maybe send the evaluations to the verifier? let round_poly = UniPoly::interpolate(&evals); round_polys.push(round_poly); @@ -104,16 +119,17 @@ impl SumCheckPhase1 { let v_B = B_table[0]; let v_C = C_table[0]; - let rx = self.challenge.clone(); - let blinder_poly_eval_claim = blinder_poly.eval(&rx); - // Prove the evaluation of the blinder polynomial at rx. + let mut rx_rev = self.challenge.clone(); + rx_rev.reverse(); + let blinder_poly_eval_proof = + pcs.open(&blinder_poly_comm, &blinder_poly, &rx_rev, transcript); ( SCPhase1Proof { blinder_poly_sum, round_polys, - blinder_poly_eval_claim, + blinder_poly_eval_proof, }, (v_A, v_B, v_C), ) @@ -125,6 +141,8 @@ impl SumCheckPhase1 { let zero = F::ZERO; let one = F::ONE; + println!("v phase 1 rho = {:?}", rho); + // target = 0 + rho * blinder_poly_sum let mut target = rho * proof.blinder_poly_sum; for (i, round_poly) in proof.round_polys.iter().enumerate() { diff --git a/shockwave_plus/src/sumcheck/sc_phase_2.rs b/shockwave_plus/src/sumcheck/sc_phase_2.rs index a5c6e70..8a85f17 100644 --- a/shockwave_plus/src/sumcheck/sc_phase_2.rs +++ b/shockwave_plus/src/sumcheck/sc_phase_2.rs @@ -1,15 +1,14 @@ -use crate::polynomial::ml_poly::MlPoly; use crate::r1cs::r1cs::Matrix; use crate::sumcheck::unipoly::UniPoly; use crate::FieldExt; use serde::{Deserialize, Serialize}; -use tensor_pcs::{EqPoly, Transcript}; +use tensor_pcs::{EqPoly, SparseMLPoly, TensorMLOpening, TensorMultilinearPCS, Transcript}; #[derive(Serialize, Deserialize)] pub struct SCPhase2Proof { pub round_polys: Vec>, pub blinder_poly_sum: F, - pub blinder_poly_eval_claim: F, + pub blinder_poly_eval_proof: TensorMLOpening, } pub struct SumCheckPhase2 { @@ -43,7 +42,11 @@ impl SumCheckPhase2 { } } - pub fn prove(&self, transcript: &mut Transcript) -> SCPhase2Proof { + pub fn prove( + &self, + pcs: &TensorMultilinearPCS, + transcript: &mut Transcript, + ) -> SCPhase2Proof { let r_A = self.r[0]; let r_B = self.r[1]; let r_C = self.r[2]; @@ -72,9 +75,12 @@ impl SumCheckPhase2 { .map(|_| F::random(&mut rng)) .collect::>(); let blinder_poly_sum = random_evals.iter().fold(F::ZERO, |acc, x| acc + x); - let blinder_poly = MlPoly::new(random_evals); + let blinder_poly = SparseMLPoly::from_dense(random_evals); + let blinder_poly_comm = pcs.commit(&blinder_poly); transcript.append_fe(&blinder_poly_sum); + transcript.append_bytes(&blinder_poly_comm.committed_tree.root); + let rho = transcript.challenge_fe(); let mut round_polys: Vec> = Vec::>::with_capacity(num_vars); @@ -83,7 +89,11 @@ impl SumCheckPhase2 { let mut B_table = B_evals.clone(); let mut C_table = C_evals.clone(); let mut Z_table = self.Z_evals.clone(); - let mut blinder_table = blinder_poly.evals.clone(); + let mut blinder_table = blinder_poly + .evals + .iter() + .map(|(_, x)| *x) + .collect::>(); let zero = F::ZERO; let one = F::ONE; @@ -119,12 +129,15 @@ impl SumCheckPhase2 { } let mut r_y_rev = self.challenge.clone(); - let blinder_poly_eval_claim = blinder_poly.eval(&r_y_rev); + r_y_rev.reverse(); + + let blinder_poly_eval_proof = + pcs.open(&blinder_poly_comm, &blinder_poly, &r_y_rev, transcript); SCPhase2Proof { round_polys, - blinder_poly_eval_claim, blinder_poly_sum, + blinder_poly_eval_proof, } } diff --git a/tensor_pcs/src/fft.rs b/tensor_pcs/src/fft.rs index d47665b..c2a4f66 100644 --- a/tensor_pcs/src/fft.rs +++ b/tensor_pcs/src/fft.rs @@ -33,8 +33,8 @@ where let fft_e = fft(&L, &domain_squared); let fft_o = fft(&R, &domain_squared); - let mut evals_L = vec![]; - let mut evals_R = vec![]; + let mut evals_L = Vec::with_capacity(coeffs.len() / 2); + let mut evals_R = Vec::with_capacity(coeffs.len() / 2); for i in 0..(coeffs.len() / 2) { // We can use the previous evaluations to create a list of evaluations // of the domain diff --git a/tensor_pcs/src/tensor_pcs.rs b/tensor_pcs/src/tensor_pcs.rs index 4053f58..4db6036 100644 --- a/tensor_pcs/src/tensor_pcs.rs +++ b/tensor_pcs/src/tensor_pcs.rs @@ -46,7 +46,7 @@ pub struct TensorMLOpening { pub base_opening: BaseOpening, pub test_query_leaves: Vec>, pub eval_query_leaves: Vec>, - u_hat_comm: [u8; 32], + pub u_hat_comm: [u8; 32], pub test_u_prime: Vec, pub test_r_prime: Vec, pub eval_r_prime: Vec, @@ -76,8 +76,6 @@ impl TensorMultilinearPCS { let num_rows = self.config.num_rows(); debug_assert_eq!(poly.num_vars, point.len()); - transcript.append_bytes(&u_hat_comm.committed_tree.root()); - // ######################################## // Testing phase // Prove the consistency between the random linear combination of the evaluation tensor (u_prime) @@ -152,30 +150,20 @@ impl TensorMultilinearPCS { } impl TensorMultilinearPCS { - pub fn verify( - &self, - opening: &TensorMLOpening, - commitment: &[u8; 32], - transcript: &mut Transcript, - ) { + pub fn verify(&self, opening: &TensorMLOpening, transcript: &mut Transcript) { let num_rows = self.config.num_rows(); let num_cols = self.config.num_cols(); - let u_hat_comm = opening.u_hat_comm; - transcript.append_bytes(&u_hat_comm); - - assert_eq!(&u_hat_comm, commitment); - // Verify the base opening - let base_opening = &opening.base_opening; - base_opening.verify(u_hat_comm); + base_opening.verify(opening.u_hat_comm); // ######################################## // Verify test phase // ######################################## let r_u = transcript.challenge_vec(num_rows); + println!("r_u = {:?}", r_u); let test_u_prime_rs_codeword = self .rs_encode(&opening.test_u_prime) @@ -375,14 +363,12 @@ mod tests { .collect::>(); let mut prover_transcript = Transcript::::new(b"test"); + prover_transcript.append_bytes(&comm.committed_tree.root); let opening = pcs.open(&comm, &ml_poly, &open_at, &mut prover_transcript); let mut verifier_transcript = Transcript::::new(b"test"); - pcs.verify( - &opening, - &comm.committed_tree.root(), - &mut verifier_transcript, - ); + verifier_transcript.append_bytes(&comm.committed_tree.root); + pcs.verify(&opening, &mut verifier_transcript); } fn config_base(ml_poly: &SparseMLPoly) -> TensorRSMultilinearPCSConfig {