use crate::rs_config::ecfft::ECFFTConfig;
use crate::tree::BaseOpening;
use crate::FieldExt;
use ecfft::extend;
use serde::{Deserialize, Serialize};

use crate::fft::fft;
use crate::polynomial::eq_poly::EqPoly;
use crate::tensor_code::TensorCode;
use crate::transcript::Transcript;
use crate::utils::{det_num_cols, det_num_rows, dot_prod, hash_all, rlc_rows, sample_indices};

use super::tensor_code::CommittedTensorCode;

#[derive(Clone)]
pub struct TensorRSMultilinearPCSConfig<F: FieldExt> {
    pub expansion_factor: usize,
    pub domain_powers: Option<Vec<Vec<F>>>,
    pub fft_domain: Option<Vec<F>>,
    pub ecfft_config: Option<ECFFTConfig<F>>,
    pub l: usize,
}

impl<F: FieldExt> TensorRSMultilinearPCSConfig<F> {
    pub fn num_cols(&self, num_entries: usize) -> usize {
        det_num_cols(num_entries, self.l)
    }

    pub fn num_rows(&self, num_entries: usize) -> usize {
        det_num_rows(num_entries, self.l)
    }
}

#[derive(Clone)]
pub struct TensorMultilinearPCS<F: FieldExt> {
    config: TensorRSMultilinearPCSConfig<F>,
}

#[derive(Clone, Serialize, Deserialize)]
pub struct TensorMLOpening<F: FieldExt> {
    pub x: Vec<F>,
    pub y: F,
    pub base_opening: BaseOpening,
    pub test_query_leaves: Vec<Vec<F>>,
    pub eval_query_leaves: Vec<Vec<F>>,
    pub u_hat_comm: [u8; 32],
    pub test_u_prime: Vec<F>,
    pub test_r_prime: Vec<F>,
    pub eval_r_prime: Vec<F>,
    pub eval_u_prime: Vec<F>,
    pub poly_num_vars: usize,
}

impl<F: FieldExt> TensorMultilinearPCS<F> {
    pub fn new(config: TensorRSMultilinearPCSConfig<F>) -> Self {
        Self { config }
    }

    pub fn commit(&self, ml_poly_evals: &[F]) -> CommittedTensorCode<F> {
        // Merkle commit to the evaluations of the polynomial
        let n = ml_poly_evals.len();
        assert!(n.is_power_of_two());
        let tensor_code = self.encode_zk(ml_poly_evals);
        tensor_code.commit(self.config.num_cols(n), self.config.num_rows(n))
    }

    pub fn open(
        &self,
        u_hat_comm: &CommittedTensorCode<F>,
        // TODO: Remove poly and use u_hat_comm
        ml_poly_evals: &[F],
        point: &[F],
        eval: F,
        transcript: &mut Transcript<F>,
    ) -> TensorMLOpening<F> {
        let n = ml_poly_evals.len();
        assert!(n.is_power_of_two());
        let num_vars = (n as f64).log2() as usize;
        let num_cols = self.config.num_cols(n);
        let num_rows = self.config.num_rows(n);
        debug_assert_eq!(num_vars, point.len());

        // ########################################
        // Testing phase
        // Prove the consistency between the random linear combination of the
        // evaluation tensor (u_prime) and the tensor codeword (u_hat)
        // ########################################

        // Derive the challenge vector
        let r_u = transcript.challenge_vec(num_rows);

        let u = (0..num_rows)
            .map(|i| ml_poly_evals[(i * num_cols)..((i + 1) * num_cols)].to_vec())
            .collect::<Vec<Vec<F>>>();

        // Random linear combination of the rows of the polynomial in a tensor structure
        let test_u_prime = rlc_rows(u.clone(), &r_u);

        // Random linear combination of the blinder
        // (the blinder is the second half of each committed row; see `split_encode`)
        let blinder = u_hat_comm
            .tensor_codeword
            .0
            .iter()
            .map(|row| row[(row.len() / 2)..].to_vec())
            .collect::<Vec<Vec<F>>>();

        debug_assert_eq!(blinder[0].len(), u_hat_comm.tensor_codeword.0[0].len() / 2);

        let test_r_prime = rlc_rows(blinder.clone(), &r_u);

        let num_indices = self.config.l;
        let indices = sample_indices(num_indices, num_cols * 2, transcript);

        let test_queries = self.test_phase(&indices, &u_hat_comm);

        // ########################################
        // Evaluation phase
        // Prove the consistency between the claimed evaluation and the committed codeword
        // ########################################

        // Get the evaluation point
        let log2_num_rows = (num_rows as f64).log2() as usize;
        let q1 = EqPoly::new(point[0..log2_num_rows].to_vec()).evals();

        let eval_r_prime = rlc_rows(blinder, &q1);
        let eval_u_prime = rlc_rows(u.clone(), &q1);

        let eval_queries = self.test_phase(&indices, &u_hat_comm);

        TensorMLOpening {
            x: point.to_vec(),
            y: eval,
            eval_query_leaves: eval_queries,
            test_query_leaves: test_queries,
            u_hat_comm: u_hat_comm.committed_tree.root(),
            test_u_prime,
            test_r_prime,
            eval_r_prime,
            eval_u_prime,
            base_opening: BaseOpening {
                hashes: u_hat_comm.committed_tree.column_roots.clone(),
            },
            poly_num_vars: num_vars,
        }
    }
}

impl<F: FieldExt> TensorMultilinearPCS<F> {
    pub fn verify(&self, opening: &TensorMLOpening<F>, transcript: &mut Transcript<F>) {
        let poly_num_entries = 2usize.pow(opening.poly_num_vars as u32);
        let num_rows = self.config.num_rows(poly_num_entries);
        let num_cols = self.config.num_cols(poly_num_entries);

        // Verify the base opening
        let base_opening = &opening.base_opening;
        base_opening.verify(opening.u_hat_comm);

        // ########################################
        // Verify test phase
        // ########################################

        let r_u = transcript.challenge_vec(num_rows);

        let test_u_prime_rs_codeword = self
            .rs_encode(&opening.test_u_prime)
            .iter()
            .zip(opening.test_r_prime.iter())
            .map(|(c, r)| *c + *r)
            .collect::<Vec<F>>();

        let num_indices = self.config.l;
        let indices = sample_indices(num_indices, num_cols * 2, transcript);

        debug_assert_eq!(indices.len(), opening.test_query_leaves.len());
        for (expected_index, leaves) in indices.iter().zip(opening.test_query_leaves.iter()) {
            // Verify that the hash of the leaves equals the corresponding column root
            let leaf_bytes = leaves.iter().map(|x| x.to_repr()).collect::<Vec<_>>();
            let column_root = hash_all(&leaf_bytes);
            let expected_column_root = base_opening.hashes[*expected_index];
            assert_eq!(column_root, expected_column_root);

            let mut sum = F::ZERO;
            for (leaf, r_i) in leaves.iter().zip(r_u.iter()) {
                sum += *r_i * *leaf;
            }
            assert_eq!(sum, test_u_prime_rs_codeword[*expected_index]);
        }

        // ########################################
        // Verify evaluation phase
        // ########################################

        let log2_num_rows = (num_rows as f64).log2() as usize;
        let q1 = EqPoly::new(opening.x[0..log2_num_rows].to_vec()).evals();
        let q2 = EqPoly::new(opening.x[log2_num_rows..].to_vec()).evals();

        let eval_u_prime_rs_codeword = self
            .rs_encode(&opening.eval_u_prime)
            .iter()
            .zip(opening.eval_r_prime.iter())
            .map(|(c, r)| *c + *r)
            .collect::<Vec<F>>();

        debug_assert_eq!(q1.len(), opening.eval_query_leaves[0].len());
        debug_assert_eq!(indices.len(), opening.test_query_leaves.len());
        for (expected_index, leaves) in indices.iter().zip(opening.eval_query_leaves.iter()) {
            // TODO: Don't need to check the leaves again?
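            // Note: this loop re-derives each queried column's Merkle root exactly as in the
            // test phase above, for the same sampled indices; only the final inner-product
            // check differs (weights q1 instead of r_u, compared against the RS-encoded
            // eval_u_prime), so the leaf hashing could plausibly be shared between the phases.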
            // Verify that the hash of the leaves equals the corresponding column root
            let leaf_bytes = leaves.iter().map(|x| x.to_repr()).collect::<Vec<_>>();
            let column_root = hash_all(&leaf_bytes);
            let expected_column_root = base_opening.hashes[*expected_index];
            assert_eq!(column_root, expected_column_root);

            let mut sum = F::ZERO;
            for (leaf, q1_i) in leaves.iter().zip(q1.iter()) {
                sum += *q1_i * *leaf;
            }
            assert_eq!(sum, eval_u_prime_rs_codeword[*expected_index]);
        }

        let expected_eval = dot_prod(&opening.eval_u_prime, &q2);
        assert_eq!(expected_eval, opening.y);
    }

    fn split_encode(&self, message: &[F]) -> Vec<F> {
        let codeword = self.rs_encode(message);

        let mut rng = rand::thread_rng();
        let blinder = (0..codeword.len())
            .map(|_| F::random(&mut rng))
            .collect::<Vec<F>>();

        let mut randomized_codeword = codeword
            .iter()
            .zip(blinder.iter())
            .map(|(c, b)| *b + *c)
            .collect::<Vec<F>>();

        randomized_codeword.extend_from_slice(&blinder);
        debug_assert_eq!(randomized_codeword.len(), codeword.len() * 2);
        randomized_codeword
    }

    fn rs_encode(&self, message: &[F]) -> Vec<F> {
        let mut padded_message = message.to_vec();
        padded_message.resize(message.len().next_power_of_two(), F::ZERO);

        let codeword_len = padded_message.len() * self.config.expansion_factor;
        let codeword_len_log2 = (codeword_len as f64).log2() as usize;

        let codeword = if self.config.fft_domain.is_some() {
            // TODO: Resize the domain according to the message length
            let fft_domain = self.config.fft_domain.as_ref().unwrap();
            let mut padded_coeffs = message.to_vec();
            padded_coeffs.resize(fft_domain.len(), F::ZERO);
            let codeword = fft(&padded_coeffs, fft_domain);

            codeword
        } else if self.config.ecfft_config.is_some() {
            let mut ecfft_config = self.config.ecfft_config.clone().unwrap();

            // Resize the domain to the correct size
            let config_domain_size = ecfft_config.domain.len();
            assert!(config_domain_size >= codeword_len_log2 - 1);
            ecfft_config.domain =
                ecfft_config.domain[(config_domain_size - (codeword_len_log2 - 1))..].to_vec();
            ecfft_config.matrices =
                ecfft_config.matrices[(config_domain_size - (codeword_len_log2 - 1))..].to_vec();
            ecfft_config.inverse_matrices = ecfft_config.inverse_matrices
                [(config_domain_size - (codeword_len_log2 - 1))..]
                .to_vec();

            assert_eq!(
                padded_message.len() * self.config.expansion_factor,
                ecfft_config.domain[0].len()
            );

            let extended_evals = extend(
                &padded_message,
                &ecfft_config.domain,
                &ecfft_config.matrices,
                &ecfft_config.inverse_matrices,
                0,
            );

            let codeword = [message.to_vec(), extended_evals].concat();

            codeword
        } else {
            // TODO: Resize the domain according to the message length
            let domain_powers = self.config.domain_powers.as_ref().unwrap();
            assert_eq!(message.len(), domain_powers[0].len());
            assert_eq!(
                message.len() * self.config.expansion_factor,
                domain_powers.len()
            );

            let codeword = domain_powers
                .iter()
                .map(|powers| {
                    message
                        .iter()
                        .zip(powers.iter())
                        .fold(F::ZERO, |acc, (m, p)| acc + *m * *p)
                })
                .collect::<Vec<F>>();

            codeword
        };

        codeword
    }

    fn test_phase(&self, indices: &[usize], u_hat_comm: &CommittedTensorCode<F>) -> Vec<Vec<F>> {
        // Query the columns of u_hat
        let num_indices = self.config.l;

        let u_hat_openings = indices
            .iter()
            .map(|index| u_hat_comm.query_column(*index))
            .collect::<Vec<Vec<F>>>();

        debug_assert_eq!(u_hat_openings.len(), num_indices);

        u_hat_openings
    }

    fn encode_zk(&self, ml_poly_evals: &[F]) -> TensorCode<F> {
        let n = ml_poly_evals.len();
        assert!(n.is_power_of_two());

        let num_rows = self.config.num_rows(n);
        let num_cols = self.config.num_cols(n);
        debug_assert_eq!(n, num_cols * num_rows);

        let codewords = (0..num_rows)
            .map(|i| &ml_poly_evals[i * num_cols..(i + 1) * num_cols])
            .map(|row| self.split_encode(row))
            .collect::<Vec<Vec<F>>>();

        TensorCode(codewords)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::polynomial::ml_poly::MlPoly;
    use crate::rs_config::{ecfft, good_curves::secp256k1::secp256k1_good_curve, naive, smooth};

    const TEST_NUM_VARS: usize = 8;
    const TEST_L: usize = 10;

    fn test_poly_evals<F: FieldExt>() -> MlPoly<F> {
        let num_entries: usize = 2usize.pow(TEST_NUM_VARS as u32);
        let evals = (0..num_entries)
            .map(|i| F::from((i + 1) as u64))
            .collect::<Vec<F>>();
        MlPoly::new(evals)
    }

    fn prove_and_verify<F: FieldExt>(ml_poly: &MlPoly<F>, pcs: TensorMultilinearPCS<F>) {
        let ml_poly_evals = &ml_poly.evals;
        let comm = pcs.commit(ml_poly_evals);

        let ml_poly_num_vars = (ml_poly_evals.len() as f64).log2() as usize;
        let open_at = (0..ml_poly_num_vars)
            .map(|i| F::from(i as u64))
            .collect::<Vec<F>>();
        let y = ml_poly.eval(&open_at);

        let mut prover_transcript = Transcript::<F>::new(b"test");
        prover_transcript.append_bytes(&comm.committed_tree.root);
        let opening = pcs.open(&comm, ml_poly_evals, &open_at, y, &mut prover_transcript);

        let mut verifier_transcript = Transcript::<F>::new(b"test");
        verifier_transcript.append_bytes(&comm.committed_tree.root);
        pcs.verify(&opening, &mut verifier_transcript);
    }

    fn config_base<F: FieldExt>() -> TensorRSMultilinearPCSConfig<F> {
        let expansion_factor = 2;

        TensorRSMultilinearPCSConfig::<F> {
            expansion_factor,
            domain_powers: None,
            fft_domain: None,
            ecfft_config: None,
            l: TEST_L,
        }
    }

    #[test]
    fn test_tensor_pcs_fft() {
        type F = halo2curves::pasta::Fp;

        // FFT config
        let ml_poly = test_poly_evals();
        let mut config = config_base();

        // The test polynomial has 2^k non-zero entries
        let num_entries = ml_poly.evals.len();
        config.fft_domain = Some(smooth::gen_config(config.num_cols(num_entries)));

        // Test FFT PCS
        let tensor_pcs_fft = TensorMultilinearPCS::<F>::new(config);
        prove_and_verify(&ml_poly, tensor_pcs_fft);
    }

    #[test]
    fn test_tensor_pcs_ecfft() {
        type F = halo2curves::secp256k1::Fp;

        let ml_poly = test_poly_evals();
        let mut config = config_base();

        let n = ml_poly.evals.len();
        let num_cols = config.num_cols(n);
        let k = ((num_cols * config.expansion_factor).next_power_of_two() as f64).log2() as usize;

        let (curve, coset_offset) = secp256k1_good_curve(k);
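        // `k` is log2 of the RS-encoded row length (num_cols * expansion_factor, rounded up to
        // a power of two), so the "good curve" is expected to supply an ECFFT coset just large
        // enough to encode one row of the tensor codeword.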
        config.ecfft_config = Some(ecfft::gen_config_form_curve(curve, coset_offset));

        // Test ECFFT PCS
        let tensor_pcs_ecfft = TensorMultilinearPCS::<F>::new(config);
        prove_and_verify(&ml_poly, tensor_pcs_ecfft);
    }

    #[test]
    fn test_tensor_pcs_naive() {
        type F = halo2curves::secp256k1::Fp;

        let ml_poly = test_poly_evals();
        let n = ml_poly.evals.len();

        // Naive config
        let mut config = config_base();
        config.domain_powers = Some(naive::gen_config(config.num_cols(n)));

        // Test naive PCS
        let tensor_pcs_naive = TensorMultilinearPCS::<F>::new(config);
        prove_and_verify(&ml_poly, tensor_pcs_naive);
    }
}