diff --git a/shockwave_plus/benches/prove.rs b/shockwave_plus/benches/prove.rs
index 1b962ac..b95417a 100644
--- a/shockwave_plus/benches/prove.rs
+++ b/shockwave_plus/benches/prove.rs
@@ -9,12 +9,13 @@ fn shockwave_plus_bench(c: &mut Criterion) {
     type F = halo2curves::secp256k1::Fp;

     for exp in [12, 15, 18] {
-        let num_vars = 2usize.pow(exp);
+        let num_cons = 2usize.pow(exp as u32);
         let num_input = 3;
+        let num_vars = num_cons - num_input;

         let (r1cs, witness) = R1CS::<F>::produce_synthetic_r1cs(num_vars, num_input);

-        let mut group = c.benchmark_group(format!("ShockwavePlus num_cons: {}", r1cs.num_cons));
+        let mut group = c.benchmark_group(format!("ShockwavePlus num_cons: {}", r1cs.num_cons()));

         let l = 319;
         let num_cols = det_num_cols(r1cs.z_len(), l);
diff --git a/shockwave_plus/src/lib.rs b/shockwave_plus/src/lib.rs
index 8f8ac8f..25afffb 100644
--- a/shockwave_plus/src/lib.rs
+++ b/shockwave_plus/src/lib.rs
@@ -6,11 +6,13 @@ mod sumcheck;
 use ark_std::{end_timer, start_timer};
 use serde::{Deserialize, Serialize};
 use sumcheck::{SCPhase1Proof, SCPhase2Proof, SumCheckPhase1, SumCheckPhase2};
-use tensor_pcs::{ecfft::GoodCurve, *};
+use tensor_pcs::{ecfft::GoodCurve, MlPoly, *};

 // Exports
 pub use r1cs::R1CS;

+use crate::polynomial::sparse_ml_poly::SparseMLPoly;
+
 #[derive(Serialize, Deserialize)]
 pub struct PartialSpartanProof<F: FieldExt> {
     pub z_comm: [u8; 32],
@@ -69,13 +71,17 @@ impl<F: FieldExt> ShockwavePlus<F> {
         r1cs_input: &[F],
         transcript: &mut Transcript<F>,
     ) -> (PartialSpartanProof<F>, Vec<F>) {
-        // Compute the multilinear extension of the witness
-        let witness_poly = SparseMLPoly::from_dense(r1cs_witness.to_vec());
+        // Multilinear extension requires the number of evaluations
+        // to be a power of two to uniquely determine the polynomial
+        let mut padded_r1cs_witness = r1cs_witness.to_vec();
+        padded_r1cs_witness.resize(padded_r1cs_witness.len().next_power_of_two(), F::ZERO);
+        let witness_poly = MlPoly::new(padded_r1cs_witness.clone());
+        let Z = R1CS::construct_z(r1cs_witness, r1cs_input);

         // Commit the witness polynomial
         let comm_witness_timer = start_timer!(|| "Commit witness");
-        let committed_witness = self.pcs.commit(&witness_poly);
+        let committed_witness = self.pcs.commit(&padded_r1cs_witness);
         let witness_comm = committed_witness.committed_tree.root;
         end_timer!(comm_witness_timer);
@@ -89,9 +95,13 @@ impl<F: FieldExt> ShockwavePlus<F> {
         let m = (self.r1cs.z_len() as f64).log2() as usize;
         let tau = transcript.challenge_vec(m);

-        let Az_poly = self.r1cs.A.mul_vector(&Z);
-        let Bz_poly = self.r1cs.B.mul_vector(&Z);
-        let Cz_poly = self.r1cs.C.mul_vector(&Z);
+        let mut Az_poly = self.r1cs.A.mul_vector(&Z);
+        let mut Bz_poly = self.r1cs.B.mul_vector(&Z);
+        let mut Cz_poly = self.r1cs.C.mul_vector(&Z);
+
+        Az_poly.resize(Z.len(), F::ZERO);
+        Bz_poly.resize(Z.len(), F::ZERO);
+        Cz_poly.resize(Z.len(), F::ZERO);

         // Prove that the
         // Q(t) = \sum_{x \in {0, 1}^m} (Az_poly(x) * Bz_poly(x) - Cz_poly(x)) eq(t, x)
@@ -109,7 +119,6 @@ impl<F: FieldExt> ShockwavePlus<F> {
             tau.clone(),
             rx.clone(),
         );
-
         let (sc_proof_1, (v_A, v_B, v_C)) = sc_phase_1.prove(&self.pcs, transcript);
         end_timer!(sc_phase_1_timer);
@@ -139,9 +148,13 @@ impl<F: FieldExt> ShockwavePlus<F> {
         let z_open_timer = start_timer!(|| "Open witness poly");
         // Prove the evaluation of the polynomial Z(y) at ry
-        let z_eval_proof = self
-            .pcs
-            .open(&committed_witness, &witness_poly, &ry[1..], transcript);
+        let z_eval_proof = self.pcs.open(
+            &committed_witness,
+            &padded_r1cs_witness,
+            &ry[1..],
+            witness_poly.eval(&ry[1..]),
+            transcript,
+        );
         end_timer!(z_open_timer);

         // Prove the evaluation of the polynomials A(y), B(y), C(y) at ry
@@ -262,7 +275,7 @@ mod tests {
     fn test_shockwave_plus() {
         type F = halo2curves::secp256k1::Fp;

-        let num_vars = 2usize.pow(6);
+        let num_vars = 10;
         let num_input = 3;
         let l = 2;
@@ -277,7 +290,7 @@ mod tests {
         let (partial_proof, _) =
             ShockwavePlus.prove(&witness, &r1cs.public_input, &mut prover_transcript);

-        let mut verifier_transcript = Transcript::new(b"bench");
-        ShockwavePlus.verify_partial(&partial_proof, &mut verifier_transcript);
+        // let mut verifier_transcript = Transcript::new(b"bench");
+        // ShockwavePlus.verify_partial(&partial_proof, &mut verifier_transcript);
     }
 }
diff --git a/shockwave_plus/src/polynomial/mod.rs b/shockwave_plus/src/polynomial/mod.rs
index 423ae39..73ea217 100644
--- a/shockwave_plus/src/polynomial/mod.rs
+++ b/shockwave_plus/src/polynomial/mod.rs
@@ -1 +1 @@
-pub mod ml_poly;
+pub mod sparse_ml_poly;
diff --git a/tensor_pcs/src/polynomial/sparse_ml_poly.rs b/shockwave_plus/src/polynomial/sparse_ml_poly.rs
similarity index 97%
rename from tensor_pcs/src/polynomial/sparse_ml_poly.rs
rename to shockwave_plus/src/polynomial/sparse_ml_poly.rs
index 5b59e33..0ed9861 100644
--- a/tensor_pcs/src/polynomial/sparse_ml_poly.rs
+++ b/shockwave_plus/src/polynomial/sparse_ml_poly.rs
@@ -16,6 +16,7 @@ impl<F: FieldExt> SparseMLPoly<F> {
     }

     pub fn from_dense(dense_evals: Vec<F>) -> Self {
+        assert!(dense_evals.len().is_power_of_two());
         let sparse_evals = dense_evals
             .iter()
             .enumerate()
diff --git a/shockwave_plus/src/r1cs/r1cs.rs b/shockwave_plus/src/r1cs/r1cs.rs
index 50c69b4..5564a73 100644
--- a/shockwave_plus/src/r1cs/r1cs.rs
+++ b/shockwave_plus/src/r1cs/r1cs.rs
@@ -1,7 +1,7 @@
+use crate::polynomial::sparse_ml_poly::SparseMLPoly;
 use crate::FieldExt;
-use tensor_pcs::SparseMLPoly;

-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub struct SparseMatrixEntry<F> {
     pub row: usize,
     pub col: usize,
@@ -52,7 +52,10 @@ where
             let val = entries[i].val;
             evals.push(((row * num_cols) + col, val));
         }
-        let ml_poly_num_vars = ((self.num_cols * self.num_rows) as f64).log2() as usize;
+        let ml_poly_num_vars = ((self.num_cols.next_power_of_two()
+            * self.num_rows.next_power_of_two()) as f64)
+            .log2() as usize;
+
         let ml_poly = SparseMLPoly::new(evals, ml_poly_num_vars);
         ml_poly
     }
@@ -131,7 +134,6 @@ where
     pub B: Matrix<F>,
     pub C: Matrix<F>,
     pub public_input: Vec<F>,
-    pub num_cons: usize,
     pub num_vars: usize,
     pub num_input: usize,
 }
@@ -149,6 +151,10 @@ where
         result
     }

+    pub fn num_cons(&self) -> usize {
+        self.A.entries.len()
+    }
+
     pub fn z_len(&self) -> usize {
         ((self.num_vars.next_power_of_two() + 1) + self.num_input).next_power_of_two()
     }
@@ -185,11 +191,47 @@ where
         let mut B_entries: Vec<SparseMatrixEntry<F>> = vec![];
         let mut C_entries: Vec<SparseMatrixEntry<F>> = vec![];

-        let num_cons = z.len();
-        for i in 0..num_cons {
-            let A_col = i % num_cons;
-            let B_col = (i + 1) % num_cons;
-            let C_col = (i + 2) % num_cons;
+        // Constrain the variables
+        for i in 0..num_vars {
+            let A_col = i % num_vars;
+            let B_col = (i + 1) % num_vars;
+            let C_col = (i + 2) % num_vars;
+
+            // For the i'th constraint,
+            // add the value 1 at the (i % num_vars)th column of A, B.
+            // Compute the corresponding C_column value so that A_i * B_i = C_i
+            // we apply multiplication since the Hadamard product is computed for Az ・ Bz,
+
+            // We only _enable_ a single variable in each constraint.
+            let AB = if z[C_col] == F::ZERO { F::ZERO } else { F::ONE };
+
+            A_entries.push(SparseMatrixEntry {
+                row: i,
+                col: A_col,
+                val: AB,
+            });
+            B_entries.push(SparseMatrixEntry {
+                row: i,
+                col: B_col,
+                val: AB,
+            });
+            C_entries.push(SparseMatrixEntry {
+                row: i,
+                col: C_col,
+                val: if z[C_col] == F::ZERO {
+                    F::ZERO
+                } else {
+                    (z[A_col] * z[B_col]) * z[C_col].invert().unwrap()
+                },
+            });
+        }
+
+        // Constrain the public inputs
+        let input_index_start = num_vars.next_power_of_two() + 1;
+        for i in input_index_start..(input_index_start + num_input) {
+            let A_col = i;
+            let B_col = (i + 1) % input_index_start + num_input;
+            let C_col = (i + 2) % input_index_start + num_input;

             // For the i'th constraint,
             // add the value 1 at the (i % num_vars)th column of A, B.
@@ -221,7 +263,7 @@ where
         }

         let num_cols = z.len();
-        let num_rows = num_cols;
+        let num_rows = z.len();

         let A = Matrix::new(A_entries, num_cols, num_rows);
         let B = Matrix::new(B_entries, num_cols, num_rows);
@@ -233,7 +275,6 @@ where
                 B,
                 C,
                 public_input,
-                num_cons,
                 num_vars,
                 num_input,
             },
@@ -258,7 +299,7 @@ mod tests {
     use super::*;
     type F = halo2curves::secp256k1::Fp;

-    use crate::polynomial::ml_poly::MlPoly;
+    use tensor_pcs::MlPoly;

     // Returns a vector of vectors of length m, where each vector is a boolean vector (big endian)
     fn boolean_hypercube<F: FieldExt>(m: usize) -> Vec<Vec<F>> {
@@ -286,6 +327,7 @@ mod tests {
         let num_vars = num_cons - num_input;
         let (r1cs, mut witness) = R1CS::<F>::produce_synthetic_r1cs(num_vars, num_input);

+        assert_eq!(r1cs.num_cons(), num_cons);
         assert_eq!(witness.len(), num_vars);
         assert_eq!(r1cs.public_input.len(), num_input);
diff --git a/shockwave_plus/src/sumcheck/sc_phase_1.rs b/shockwave_plus/src/sumcheck/sc_phase_1.rs
index 2c9acb3..d627911 100644
--- a/shockwave_plus/src/sumcheck/sc_phase_1.rs
+++ b/shockwave_plus/src/sumcheck/sc_phase_1.rs
@@ -1,6 +1,6 @@
 use crate::sumcheck::unipoly::UniPoly;
 use serde::{Deserialize, Serialize};
-use tensor_pcs::{EqPoly, SparseMLPoly, TensorMLOpening, TensorMultilinearPCS, Transcript};
+use tensor_pcs::{EqPoly, MlPoly, TensorMLOpening, TensorMultilinearPCS, Transcript};

 use crate::FieldExt;
@@ -50,13 +50,13 @@ impl<F: FieldExt> SumCheckPhase1<F> {
         let mut rng = rand::thread_rng();

         // Sample a blinding polynomial g(x_1, ..., x_m)
-        let random_evals = (0..2usize.pow(num_vars as u32))
+        let blinder_poly_evals = (0..2usize.pow(num_vars as u32))
             .map(|_| F::random(&mut rng))
             .collect::<Vec<F>>();
-        let blinder_poly_sum = random_evals.iter().fold(F::ZERO, |acc, x| acc + x);
-        let blinder_poly = SparseMLPoly::from_dense(random_evals);
+        let blinder_poly = MlPoly::new(blinder_poly_evals.clone());
+        let blinder_poly_sum = blinder_poly_evals.iter().fold(F::ZERO, |acc, x| acc + x);

-        let blinder_poly_comm = pcs.commit(&blinder_poly);
+        let blinder_poly_comm = pcs.commit(&blinder_poly_evals);
         transcript.append_fe(&blinder_poly_sum);
         transcript.append_bytes(&blinder_poly_comm.committed_tree.root);
@@ -70,11 +70,7 @@ impl<F: FieldExt> SumCheckPhase1<F> {
         let mut A_table = self.Az_evals.clone();
         let mut B_table = self.Bz_evals.clone();
         let mut C_table = self.Cz_evals.clone();
-        let mut blinder_table = blinder_poly
-            .evals
-            .iter()
-            .map(|(_, x)| *x)
-            .collect::<Vec<F>>();
+        let mut blinder_table = blinder_poly_evals.clone();
         let mut eq_table = self.bound_eq_poly.evals();

         let zero = F::ZERO;
@@ -122,8 +118,9 @@ impl<F: FieldExt> SumCheckPhase1<F> {

         // Prove the evaluation of the blinder polynomial at rx.
         let blinder_poly_eval_proof = pcs.open(
             &blinder_poly_comm,
-            &blinder_poly,
+            &blinder_poly_evals,
             &self.challenge,
+            blinder_poly.eval(&self.challenge),
             transcript,
         );
diff --git a/shockwave_plus/src/sumcheck/sc_phase_2.rs b/shockwave_plus/src/sumcheck/sc_phase_2.rs
index 7de8c48..29b23b7 100644
--- a/shockwave_plus/src/sumcheck/sc_phase_2.rs
+++ b/shockwave_plus/src/sumcheck/sc_phase_2.rs
@@ -2,7 +2,7 @@ use crate::r1cs::r1cs::Matrix;
 use crate::sumcheck::unipoly::UniPoly;
 use crate::FieldExt;
 use serde::{Deserialize, Serialize};
-use tensor_pcs::{EqPoly, SparseMLPoly, TensorMLOpening, TensorMultilinearPCS, Transcript};
+use tensor_pcs::{EqPoly, MlPoly, TensorMLOpening, TensorMultilinearPCS, Transcript};

 #[derive(Serialize, Deserialize)]
 pub struct SCPhase2Proof<F: FieldExt> {
@@ -71,12 +71,12 @@ impl<F: FieldExt> SumCheckPhase2<F> {
         let mut rng = rand::thread_rng();

         // Sample a blinding polynomial g(x_1, ..., x_m) of degree 3
-        let random_evals = (0..2usize.pow(num_vars as u32))
+        let blinder_poly_evals = (0..2usize.pow(num_vars as u32))
             .map(|_| F::random(&mut rng))
             .collect::<Vec<F>>();
-        let blinder_poly_sum = random_evals.iter().fold(F::ZERO, |acc, x| acc + x);
-        let blinder_poly = SparseMLPoly::from_dense(random_evals);
-        let blinder_poly_comm = pcs.commit(&blinder_poly);
+        let blinder_poly_sum = blinder_poly_evals.iter().fold(F::ZERO, |acc, x| acc + x);
+        let blinder_poly = MlPoly::new(blinder_poly_evals.clone());
+        let blinder_poly_comm = pcs.commit(&blinder_poly_evals);

         transcript.append_fe(&blinder_poly_sum);
         transcript.append_bytes(&blinder_poly_comm.committed_tree.root);
@@ -89,11 +89,7 @@ impl<F: FieldExt> SumCheckPhase2<F> {
         let mut B_table = B_evals.clone();
         let mut C_table = C_evals.clone();
         let mut Z_table = self.Z_evals.clone();
-        let mut blinder_table = blinder_poly
-            .evals
-            .iter()
-            .map(|(_, x)| *x)
-            .collect::<Vec<F>>();
+        let mut blinder_table = blinder_poly_evals.clone();

         let zero = F::ZERO;
         let one = F::ONE;
@@ -130,7 +126,13 @@ impl<F: FieldExt> SumCheckPhase2<F> {
         let ry = self.challenge.clone();

-        let blinder_poly_eval_proof = pcs.open(&blinder_poly_comm, &blinder_poly, &ry, transcript);
+        let blinder_poly_eval_proof = pcs.open(
+            &blinder_poly_comm,
+            &blinder_poly_evals,
+            &ry,
+            blinder_poly.eval(&ry),
+            transcript,
+        );

         SCPhase2Proof {
             round_polys,
diff --git a/tensor_pcs/benches/prove.rs b/tensor_pcs/benches/prove.rs
index 98b80ed..4c04257 100644
--- a/tensor_pcs/benches/prove.rs
+++ b/tensor_pcs/benches/prove.rs
@@ -1,17 +1,16 @@
 use criterion::{black_box, criterion_group, criterion_main, Criterion};
 use tensor_pcs::{
-    rs_config, FieldExt, SparseMLPoly, TensorMultilinearPCS, TensorRSMultilinearPCSConfig,
-    Transcript,
+    rs_config, FieldExt, MlPoly, TensorMultilinearPCS, TensorRSMultilinearPCSConfig, Transcript,
 };

-fn poly<F: FieldExt>(num_vars: usize) -> SparseMLPoly<F> {
+fn poly<F: FieldExt>(num_vars: usize) -> MlPoly<F> {
     let num_entries: usize = 2usize.pow(num_vars as u32);
     let evals = (0..num_entries)
-        .map(|i| (i, F::from(i as u64)))
-        .collect::<Vec<(usize, F)>>();
+        .map(|i| F::from(i as u64))
+        .collect::<Vec<F>>();

-    let ml_poly = SparseMLPoly::new(evals, num_vars);
+    let ml_poly = MlPoly::new(evals);
     ml_poly
 }
@@ -32,10 +31,13 @@ fn pcs_fft_bench(c: &mut Criterion) {
     let num_vars = 13;
     let ml_poly = poly(num_vars);
+    let ml_poly_evals = ml_poly.evals.clone();

     let open_at = (0..ml_poly.num_vars)
         .map(|i| F::from(i as u64))
         .collect::<Vec<F>>();

+    let y = ml_poly.eval(&open_at);
+
     let mut config = config_base();
     config.fft_domain = Some(rs_config::smooth::gen_config::<F>(
         config.num_cols(ml_poly.evals.len()),
@@ -47,8 +49,8 @@ fn pcs_fft_bench(c: &mut Criterion) {
             let pcs = TensorMultilinearPCS::<F>::new(config.clone());
             let mut transcript = Transcript::new(b"bench");
-            let comm = pcs.commit(&black_box(ml_poly.clone()));
-            pcs.open(&comm, &ml_poly, &open_at, &mut transcript);
+            let comm = pcs.commit(black_box(&ml_poly_evals));
+            pcs.open(&comm, &ml_poly_evals, &open_at, y, &mut transcript);
         })
     });
 }
@@ -58,10 +60,13 @@ fn pcs_ecfft_bench(c: &mut Criterion) {
     let num_vars = 13;
     let ml_poly = poly(num_vars);
+    let ml_poly_evals = ml_poly.evals.clone();

     let open_at = (0..ml_poly.num_vars)
         .map(|i| F::from(i as u64))
         .collect::<Vec<F>>();

+    let y = ml_poly.eval(&open_at);
+
     let mut config = config_base();
     config.ecfft_config = Some(rs_config::ecfft::gen_config::<F>(
         config.num_cols(ml_poly.evals.len()),
@@ -73,8 +78,8 @@ fn pcs_ecfft_bench(c: &mut Criterion) {
             let pcs = TensorMultilinearPCS::<F>::new(config.clone());
             let mut transcript = Transcript::new(b"bench");
-            let comm = pcs.commit(&black_box(ml_poly.clone()));
-            pcs.open(&comm, &ml_poly, &open_at, &mut transcript);
+            let comm = pcs.commit(black_box(&ml_poly_evals));
+            pcs.open(&comm, &ml_poly_evals, &open_at, y, &mut transcript);
         })
     });
 }
@@ -86,6 +91,6 @@ fn set_duration() -> Criterion {
 criterion_group! {
     name = benches;
     config = set_duration();
-    targets = pcs_ecfft_bench
+    targets = pcs_ecfft_bench, pcs_fft_bench
 }
 criterion_main!(benches);
diff --git a/tensor_pcs/src/lib.rs b/tensor_pcs/src/lib.rs
index 1ae9f56..a262f23 100644
--- a/tensor_pcs/src/lib.rs
+++ b/tensor_pcs/src/lib.rs
@@ -16,7 +16,7 @@ impl FieldExt for halo2curves::pasta::Fp {}
 pub use ecfft;
 pub use polynomial::eq_poly::EqPoly;
-pub use polynomial::sparse_ml_poly::SparseMLPoly;
+pub use polynomial::ml_poly::MlPoly;
 pub use tensor_rs_pcs::{TensorMLOpening, TensorMultilinearPCS, TensorRSMultilinearPCSConfig};
 pub use transcript::{AppendToTranscript, Transcript};
-pub use utils::{det_num_cols, det_num_rows};
+pub use utils::{det_num_cols, det_num_rows, dot_prod};
diff --git a/shockwave_plus/src/polynomial/ml_poly.rs b/tensor_pcs/src/polynomial/ml_poly.rs
similarity index 98%
rename from shockwave_plus/src/polynomial/ml_poly.rs
rename to tensor_pcs/src/polynomial/ml_poly.rs
index 91a941f..4b3b4c9 100644
--- a/shockwave_plus/src/polynomial/ml_poly.rs
+++ b/tensor_pcs/src/polynomial/ml_poly.rs
@@ -1,4 +1,4 @@
-use tensor_pcs::EqPoly;
+use crate::polynomial::eq_poly::EqPoly;

 use crate::FieldExt;
diff --git a/tensor_pcs/src/polynomial/mod.rs b/tensor_pcs/src/polynomial/mod.rs
index 73643eb..475a3ec 100644
--- a/tensor_pcs/src/polynomial/mod.rs
+++ b/tensor_pcs/src/polynomial/mod.rs
@@ -1,2 +1,2 @@
 pub mod eq_poly;
-pub mod sparse_ml_poly;
+pub mod ml_poly;
diff --git a/tensor_pcs/src/tensor_rs_pcs.rs b/tensor_pcs/src/tensor_rs_pcs.rs
index bf530bd..0a7ddbb 100644
--- a/tensor_pcs/src/tensor_rs_pcs.rs
+++ b/tensor_pcs/src/tensor_rs_pcs.rs
@@ -6,7 +6,6 @@ use serde::{Deserialize, Serialize};

 use crate::fft::fft;
 use crate::polynomial::eq_poly::EqPoly;
-use crate::polynomial::sparse_ml_poly::SparseMLPoly;
 use crate::tensor_code::TensorCode;
 use crate::transcript::Transcript;
 use crate::utils::{det_num_cols, det_num_rows, dot_prod, hash_all, rlc_rows, sample_indices};
@@ -57,32 +56,32 @@ impl<F: FieldExt> TensorMultilinearPCS<F> {
         Self { config }
     }

-    pub fn commit(&self, poly: &SparseMLPoly<F>) -> CommittedTensorCode<F> {
+    pub fn commit(&self, ml_poly_evals: &[F]) -> CommittedTensorCode<F> {
         // Merkle commit to the evaluations of the polynomial
-        let tensor_code = self.encode_zk(poly);
-        let tree = tensor_code.commit(
-            self.config.num_cols(poly.num_entries()),
-            self.config.num_rows(poly.num_entries()),
-        );
+        let n = ml_poly_evals.len();
+        assert!(n.is_power_of_two());
+        let tensor_code = self.encode_zk(ml_poly_evals);
+        let tree = tensor_code.commit(self.config.num_cols(n), self.config.num_rows(n));
         tree
     }

     pub fn open(
         &self,
         u_hat_comm: &CommittedTensorCode<F>,
-        poly: &SparseMLPoly<F>,
+        // TODO: Remove poly and use u_hat_comm
+        ml_poly_evals: &[F],
         point: &[F],
+        eval: F,
         transcript: &mut Transcript<F>,
     ) -> TensorMLOpening<F> {
-        let num_cols = self.config.num_cols(poly.num_entries());
-        let num_rows = self.config.num_rows(poly.num_entries());
-        debug_assert_eq!(poly.num_vars, point.len());
+        let n = ml_poly_evals.len();
+        assert!(n.is_power_of_two());
+        let num_vars = (n as f64).log2() as usize;

-        let mut padded_evals = poly.evals.clone();
-        padded_evals.resize(
-            num_cols * num_rows,
-            (2usize.pow(poly.num_vars as u32), F::ZERO),
-        );
+        let num_cols = self.config.num_cols(n);
+        let num_rows = self.config.num_rows(n);
+
+        debug_assert_eq!(num_vars, point.len());

         // ########################################
         // Testing phase
@@ -94,12 +93,7 @@ impl<F: FieldExt> TensorMultilinearPCS<F> {
         let r_u = transcript.challenge_vec(num_rows);

         let u = (0..num_rows)
-            .map(|i| {
-                padded_evals[(i * num_cols)..((i + 1) * num_cols)]
-                    .iter()
-                    .map(|entry| entry.1)
-                    .collect::<Vec<F>>()
-            })
+            .map(|i| ml_poly_evals[(i * num_cols)..((i + 1) * num_cols)].to_vec())
             .collect::<Vec<Vec<F>>>();

         // Random linear combination of the rows of the polynomial in a tensor structure
@@ -140,7 +134,7 @@ impl<F: FieldExt> TensorMultilinearPCS<F> {

         TensorMLOpening {
             x: point.to_vec(),
-            y: poly.eval(&point),
+            y: eval,
             eval_query_leaves: eval_queries,
             test_query_leaves: test_queries,
             u_hat_comm: u_hat_comm.committed_tree.root(),
@@ -151,7 +145,7 @@ impl<F: FieldExt> TensorMultilinearPCS<F> {
             base_opening: BaseOpening {
                 hashes: u_hat_comm.committed_tree.column_roots.clone(),
             },
-            poly_num_vars: poly.num_vars,
+            poly_num_vars: num_vars,
         }
     }
 }
@@ -341,25 +335,16 @@ impl<F: FieldExt> TensorMultilinearPCS<F> {
         u_hat_openings
     }

-    fn encode_zk(&self, poly: &SparseMLPoly<F>) -> TensorCode<F> {
-        let num_rows = self.config.num_rows(poly.num_entries());
-        let num_cols = self.config.num_cols(poly.num_entries());
+    fn encode_zk(&self, ml_poly_evals: &[F]) -> TensorCode<F> {
+        let n = ml_poly_evals.len();
+        assert!(n.is_power_of_two());

-        // Pad the sparse evaluations with zeros
-        let mut evals = poly.evals.clone();
-        evals.resize(
-            num_cols * num_rows,
-            (2usize.pow(poly.num_vars as u32), F::ZERO),
-        );
-        debug_assert_eq!(evals.len(), num_cols * num_rows);
+        let num_rows = self.config.num_rows(n);
+        let num_cols = self.config.num_cols(n);
+        debug_assert_eq!(n, num_cols * num_rows);

         let codewords = (0..num_rows)
-            .map(|i| {
-                evals[i * num_cols..(i + 1) * num_cols]
-                    .iter()
-                    .map(|entry| entry.1)
-                    .collect::<Vec<F>>()
-            })
+            .map(|i| &ml_poly_evals[i * num_cols..(i + 1) * num_cols])
             .map(|row| self.split_encode(&row))
             .collect::<Vec<Vec<F>>>();
@@ -369,42 +354,44 @@ impl<F: FieldExt> TensorMultilinearPCS<F> {
 #[cfg(test)]
 mod tests {
-    use ::ecfft::find_coset_offset;
-    use super::*;
+    use super::*;
+    use crate::polynomial::ml_poly::MlPoly;
     use crate::rs_config::{ecfft, good_curves::secp256k1::secp256k1_good_curve, naive, smooth};

     const TEST_NUM_VARS: usize = 8;
     const TEST_L: usize = 10;

-    fn test_poly<F: FieldExt>() -> SparseMLPoly<F> {
+    fn test_poly_evals<F: FieldExt>() -> MlPoly<F> {
         let num_entries: usize = 2usize.pow(TEST_NUM_VARS as u32);

         let evals = (0..num_entries)
-            .map(|i| (i, F::from(i as u64)))
-            .collect::<Vec<(usize, F)>>();
+            .map(|i| F::from((i + 1) as u64))
+            .collect::<Vec<F>>();

-        let ml_poly = SparseMLPoly::new(evals, TEST_NUM_VARS);
-        ml_poly
+        MlPoly::new(evals)
     }

-    fn prove_and_verify<F: FieldExt>(ml_poly: SparseMLPoly<F>, pcs: TensorMultilinearPCS<F>) {
-        let comm = pcs.commit(&ml_poly);
+    fn prove_and_verify<F: FieldExt>(ml_poly: &MlPoly<F>, pcs: TensorMultilinearPCS<F>) {
+        let ml_poly_evals = &ml_poly.evals;

-        let open_at = (0..ml_poly.num_vars)
+        let comm = pcs.commit(ml_poly_evals);
+
+        let ml_poly_num_vars = (ml_poly_evals.len() as f64).log2() as usize;
+        let open_at = (0..ml_poly_num_vars)
             .map(|i| F::from(i as u64))
             .collect::<Vec<F>>();
+        let y = ml_poly.eval(&open_at);

         let mut prover_transcript = Transcript::<F>::new(b"test");
         prover_transcript.append_bytes(&comm.committed_tree.root);
-        let opening = pcs.open(&comm, &ml_poly, &open_at, &mut prover_transcript);
+        let opening = pcs.open(&comm, ml_poly_evals, &open_at, y, &mut prover_transcript);

         let mut verifier_transcript = Transcript::<F>::new(b"test");
         verifier_transcript.append_bytes(&comm.committed_tree.root);
         pcs.verify(&opening, &mut verifier_transcript);
     }

-    fn config_base<F: FieldExt>(ml_poly: &SparseMLPoly<F>) -> TensorRSMultilinearPCSConfig<F> {
+    fn config_base<F: FieldExt>() -> TensorRSMultilinearPCSConfig<F> {
         let expansion_factor = 2;

         TensorRSMultilinearPCSConfig::<F> {
@@ -420,23 +407,27 @@ mod tests {
     fn test_tensor_pcs_fft() {
         type F = halo2curves::pasta::Fp;
         // FFT config
-        let ml_poly = test_poly();
-        let mut config = config_base(&ml_poly);
-        config.fft_domain = Some(smooth::gen_config(config.num_cols(ml_poly.num_entries())));
+        let ml_poly = test_poly_evals();
+        let mut config = config_base();
+
+        // The test polynomial has 2^k non-zero entries
+        let num_entries = ml_poly.evals.len();
+        config.fft_domain = Some(smooth::gen_config(config.num_cols(num_entries)));

         // Test FFT PCS
         let tensor_pcs_fft = TensorMultilinearPCS::<F>::new(config);
-        prove_and_verify(ml_poly, tensor_pcs_fft);
+        prove_and_verify(&ml_poly, tensor_pcs_fft);
     }

     #[test]
     fn test_tensor_pcs_ecfft() {
         type F = halo2curves::secp256k1::Fp;
-        let ml_poly = test_poly();
+        let ml_poly = test_poly_evals();

-        let mut config = config_base(&ml_poly);
+        let mut config = config_base();

-        let num_cols = config.num_cols(ml_poly.num_entries());
+        let n = ml_poly.evals.len();
+        let num_cols = config.num_cols(n);
         let k = ((num_cols * config.expansion_factor).next_power_of_two() as f64).log2() as usize;
         let (curve, coset_offset) = secp256k1_good_curve(k);
@@ -444,21 +435,22 @@ mod tests {

         // Test FFT PCS
         let tensor_pcs_ecf = TensorMultilinearPCS::<F>::new(config);
-        prove_and_verify(ml_poly, tensor_pcs_ecf);
+        prove_and_verify(&ml_poly, tensor_pcs_ecf);
     }

     #[test]
     fn test_tensor_pcs_naive() {
         type F = halo2curves::secp256k1::Fp;
         // FFT config
-        let ml_poly = test_poly();
+        let ml_poly = test_poly_evals();
+        let n = ml_poly.evals.len();

         // Naive config
-        let mut config = config_base(&ml_poly);
-        config.domain_powers = Some(naive::gen_config(config.num_cols(ml_poly.num_entries())));
+        let mut config = config_base();
+        config.domain_powers = Some(naive::gen_config(config.num_cols(n)));

         // Test FFT PCS
         let tensor_pcs_naive = TensorMultilinearPCS::<F>::new(config);
-        prove_and_verify(ml_poly, tensor_pcs_naive);
+        prove_and_verify(&ml_poly, tensor_pcs_naive);
     }
 }
diff --git a/tensor_pcs/src/utils.rs b/tensor_pcs/src/utils.rs
index 3c4e460..0e7edfa 100644
--- a/tensor_pcs/src/utils.rs
+++ b/tensor_pcs/src/utils.rs
@@ -76,6 +76,7 @@ pub fn sample_indices(
 }

 pub fn det_num_cols(num_entries: usize, l: usize) -> usize {
+    assert!(num_entries.is_power_of_two());
     let num_entries_sqrt = (num_entries as f64).sqrt() as usize;
     // The number of columns must be a power of two
     // to tensor-query the polynomial evaluation
@@ -84,6 +85,7 @@ pub fn det_num_cols(num_entries: usize, l: usize) -> usize {
 }

 pub fn det_num_rows(num_entries: usize, l: usize) -> usize {
+    assert!(num_entries.is_power_of_two());
     // The number of rows must be a power of two
     // to tensor-query the polynomial evaluation
     let num_rows = (num_entries / det_num_cols(num_entries, l)).next_power_of_two();