
feat: public inputs

Branch: main
Author: Daniel Tehrani, 2 years ago
Parent commit: cda41a9374
12 changed files with 274 additions and 143 deletions
  1. shockwave_plus/Cargo.toml (+2, -1)
  2. shockwave_plus/benches/prove.rs (+5, -6)
  3. shockwave_plus/src/lib.rs (+52, -37)
  4. shockwave_plus/src/polynomial/ml_poly.rs (+34, -2)
  5. shockwave_plus/src/r1cs/r1cs.rs (+105, -56)
  6. shockwave_plus/src/sumcheck/sc_phase_1.rs (+6, -6)
  7. shockwave_plus/src/sumcheck/sc_phase_2.rs (+2, -4)
  8. shockwave_plus/src/utils.rs (+2, -1)
  9. tensor_pcs/Cargo.toml (+2, -1)
  10. tensor_pcs/src/polynomial/eq_poly.rs (+24, -18)
  11. tensor_pcs/src/polynomial/sparse_ml_poly.rs (+36, -1)
  12. tensor_pcs/src/tensor_pcs.rs (+4, -10)

shockwave_plus/Cargo.toml (+2, -1)

@@ -17,7 +17,8 @@ serde = { version = "1.0.152", features = ["derive"] }
 criterion = { version = "0.4", features = ["html_reports"] }

 [[bench]]
-name = "prove"
+name = "shockwave-plus"
+path = "benches/prove.rs"
 harness = false

 [features]

shockwave_plus/benches/prove.rs (+5, -6)

@@ -8,13 +8,12 @@ fn shockwave_plus_bench(c: &mut Criterion) {
     type F = halo2curves::secp256k1::Fp;

     for exp in [12, 15, 18] {
-        let num_cons = 2usize.pow(exp);
-        let num_vars = num_cons;
-        let num_input = 0;
+        let num_vars = 2usize.pow(exp);
+        let num_input = 3;

-        let (r1cs, witness) = R1CS::<F>::produce_synthetic_r1cs(num_cons, num_vars, num_input);
+        let (r1cs, witness) = R1CS::<F>::produce_synthetic_r1cs(num_vars, num_input);

-        let mut group = c.benchmark_group(format!("ShockwavePlus num_cons: {}", num_cons));
+        let mut group = c.benchmark_group(format!("ShockwavePlus num_cons: {}", r1cs.num_cons));
         let l = 319;
         let num_rows = (((2f64 / l as f64).sqrt() * (num_vars as f64).sqrt()) as usize)
             .next_power_of_two()

@@ -23,7 +22,7 @@ fn shockwave_plus_bench(c: &mut Criterion) {
         group.bench_function("prove", |b| {
            b.iter(|| {
                let mut transcript = Transcript::new(b"bench");
-                ShockwavePlus.prove(&witness, &mut transcript);
+                ShockwavePlus.prove(&witness, &r1cs.public_input, &mut transcript);
            })
        });
    }

shockwave_plus/src/lib.rs (+52, -37)

@@ -33,6 +33,7 @@ pub struct FullSpartanProof {
 pub struct ShockwavePlus<F: FieldExt> {
     pub r1cs: R1CS<F>,
     pub pcs_witness: TensorMultilinearPCS<F>,
+    pub pcs_blinder: TensorMultilinearPCS<F>,
 }

 impl<F: FieldExt> ShockwavePlus<F> {

@@ -44,7 +45,7 @@ impl ShockwavePlus {
         let expansion_factor = 2;

-        let ecfft_config = rs_config::ecfft::gen_config(num_cols);
+        let ecfft_config = rs_config::ecfft::gen_config(num_cols.next_power_of_two());

         let pcs_config = TensorRSMultilinearPCSConfig::<F> {
             expansion_factor,

@@ -57,17 +58,37 @@ impl ShockwavePlus {
         };

         let pcs_witness = TensorMultilinearPCS::new(pcs_config);
-        Self { r1cs, pcs_witness }
+
+        let ecfft_config_blinder =
+            rs_config::ecfft::gen_config((r1cs.z_len() / num_rows).next_power_of_two());
+
+        let pcs_blinder_config = TensorRSMultilinearPCSConfig::<F> {
+            expansion_factor,
+            domain_powers: None,
+            fft_domain: None,
+            ecfft_config: Some(ecfft_config_blinder),
+            l,
+            num_entries: r1cs.z_len(),
+            num_rows,
+        };
+
+        let pcs_blinder = TensorMultilinearPCS::new(pcs_blinder_config);
+
+        Self {
+            r1cs,
+            pcs_witness,
+            pcs_blinder,
+        }
     }

     pub fn prove(
         &self,
         r1cs_witness: &[F],
+        r1cs_input: &[F],
         transcript: &mut Transcript<F>,
     ) -> (PartialSpartanProof<F>, Vec<F>) {
         // Compute the multilinear extension of the witness
+        assert!(r1cs_witness.len().is_power_of_two());
         let witness_poly = SparseMLPoly::from_dense(r1cs_witness.to_vec());
+        let Z = R1CS::construct_z(r1cs_witness, r1cs_input);

         // Commit the witness polynomial
         let comm_witness_timer = start_timer!(|| "Commit witness");

@@ -82,15 +103,12 @@ impl ShockwavePlus {
         // Phase 1
         // ###################

-        let m = (self.r1cs.num_vars as f64).log2() as usize;
+        let m = (self.r1cs.z_len() as f64).log2() as usize;
         let tau = transcript.challenge_vec(m);
-        let mut tau_rev = tau.clone();
-        tau_rev.reverse();

-        let num_rows = self.r1cs.num_cons;
-        let Az_poly = self.r1cs.A.mul_vector(num_rows, r1cs_witness);
-        let Bz_poly = self.r1cs.B.mul_vector(num_rows, r1cs_witness);
-        let Cz_poly = self.r1cs.C.mul_vector(num_rows, r1cs_witness);
+        let Az_poly = self.r1cs.A.mul_vector(&Z);
+        let Bz_poly = self.r1cs.B.mul_vector(&Z);
+        let Cz_poly = self.r1cs.C.mul_vector(&Z);

         // Prove that the
         // Q(t) = \sum_{x \in {0, 1}^m} (Az_poly(x) * Bz_poly(x) - Cz_poly(x)) eq(t, x)

@@ -98,8 +116,6 @@ impl ShockwavePlus {
         // We evaluate Q(t) at $\tau$ and check that it is zero.

         let rx = transcript.challenge_vec(m);
-        let mut rx_rev = rx.clone();
-        rx_rev.reverse();

         let sc_phase_1_timer = start_timer!(|| "Sumcheck phase 1");

@@ -107,10 +123,10 @@ impl ShockwavePlus {
             Az_poly.clone(),
             Bz_poly.clone(),
             Cz_poly.clone(),
-            tau_rev.clone(),
+            tau.clone(),
             rx.clone(),
         );

-        let (sc_proof_1, (v_A, v_B, v_C)) = sc_phase_1.prove(&self.pcs_witness, transcript);
+        let (sc_proof_1, (v_A, v_B, v_C)) = sc_phase_1.prove(&self.pcs_blinder, transcript);
         end_timer!(sc_phase_1_timer);

         transcript.append_fe(&v_A);

@@ -128,27 +144,25 @@ impl ShockwavePlus {
             self.r1cs.A.clone(),
             self.r1cs.B.clone(),
             self.r1cs.C.clone(),
-            r1cs_witness.to_vec(),
+            Z.clone(),
             rx.clone(),
             r.as_slice().try_into().unwrap(),
             ry.clone(),
         );

-        let sc_proof_2 = sc_phase_2.prove(&self.pcs_witness, transcript);
+        let sc_proof_2 = sc_phase_2.prove(&self.pcs_blinder, transcript);
         end_timer!(sc_phase_2_timer);

-        let mut ry_rev = ry.clone();
-        ry_rev.reverse();
-
         let z_open_timer = start_timer!(|| "Open witness poly");
         // Prove the evaluation of the polynomial Z(y) at ry
         let z_eval_proof =
             self.pcs_witness
-                .open(&committed_witness, &witness_poly, &ry_rev, transcript);
+                .open(&committed_witness, &witness_poly, &ry[1..], transcript);
         end_timer!(z_open_timer);

         // Prove the evaluation of the polynomials A(y), B(y), C(y) at ry
-        let rx_ry = vec![ry_rev, rx_rev].concat();
+        let rx_ry = vec![ry, rx].concat();

         (
             PartialSpartanProof {
                 z_comm: witness_comm,

@@ -174,11 +188,9 @@ impl ShockwavePlus {
         let B_mle = self.r1cs.B.to_ml_extension();
         let C_mle = self.r1cs.C.to_ml_extension();

-        let m = (self.r1cs.num_vars as f64).log2() as usize;
+        let m = (self.r1cs.z_len() as f64).log2() as usize;
         let tau = transcript.challenge_vec(m);
         let rx = transcript.challenge_vec(m);
-        let mut rx_rev = rx.clone();
-        rx_rev.reverse();

         transcript.append_fe(&partial_proof.sc_proof_1.blinder_poly_sum);
         transcript.append_bytes(&partial_proof.sc_proof_1.blinder_poly_eval_proof.u_hat_comm);

@@ -187,7 +199,7 @@ impl ShockwavePlus {
         let ex = SumCheckPhase1::verify_round_polys(&partial_proof.sc_proof_1, &rx, rho);

-        self.pcs_witness.verify(
+        self.pcs_blinder.verify(
             &partial_proof.sc_proof_1.blinder_poly_eval_proof,
             transcript,
         );

@@ -198,7 +210,7 @@ impl ShockwavePlus {
         let v_C = partial_proof.v_C;

         let T_1_eq = EqPoly::new(tau);
-        let T_1 = (v_A * v_B - v_C) * T_1_eq.eval(&rx_rev)
+        let T_1 = (v_A * v_B - v_C) * T_1_eq.eval(&rx)
             + rho * partial_proof.sc_proof_1.blinder_poly_eval_proof.y;
         assert_eq!(T_1, ex);

@@ -223,18 +235,15 @@ impl ShockwavePlus {
         let final_poly_eval =
             SumCheckPhase2::verify_round_polys(T_2, &partial_proof.sc_proof_2, &ry);

-        self.pcs_witness.verify(
+        self.pcs_blinder.verify(
             &partial_proof.sc_proof_2.blinder_poly_eval_proof,
             transcript,
         );

-        let mut ry_rev = ry.clone();
-        ry_rev.reverse();
+        assert_eq!(partial_proof.z_eval_proof.x, ry[1..]);

-        let rx_ry = [rx, ry.clone()].concat();
+        let rx_ry = [rx, ry].concat();

-        assert_eq!(partial_proof.z_eval_proof.x, ry_rev);
-        let z_eval = partial_proof.z_eval_proof.y;
+        let witness_eval = partial_proof.z_eval_proof.y;
         let A_eval = A_mle.eval(&rx_ry);
         let B_eval = B_mle.eval(&rx_ry);
         let C_eval = C_mle.eval(&rx_ry);

@@ -242,6 +251,12 @@ impl ShockwavePlus {
         self.pcs_witness
             .verify(&partial_proof.z_eval_proof, transcript);

+        let input = R1CS::construct_z(&vec![F::ZERO; self.r1cs.num_vars], &self.r1cs.public_input);
+        let input_poly = SparseMLPoly::from_dense(input);
+        let input_poly_eval = input_poly.eval(&ry);
+
+        let z_eval = (F::ONE - ry[0]) * witness_eval + input_poly_eval;
+
         let T_opened = (r_A * A_eval + r_B * B_eval + r_C * C_eval) * z_eval
             + rho_2 * partial_proof.sc_proof_2.blinder_poly_eval_proof.y;
         assert_eq!(T_opened, final_poly_eval);

@@ -256,17 +271,17 @@ mod tests {
     fn test_shockwave_plus() {
         type F = halo2curves::secp256k1::Fp;

-        let num_cons = 2usize.pow(6);
-        let num_vars = num_cons;
-        let num_input = 0;
+        let num_vars = 2usize.pow(7);
+        let num_input = 3;
         let l = 10;

-        let (r1cs, witness) = R1CS::<F>::produce_synthetic_r1cs(num_cons, num_vars, num_input);
+        let (r1cs, witness) = R1CS::<F>::produce_synthetic_r1cs(num_vars, num_input);
         let num_rows = 4;

         let ShockwavePlus = ShockwavePlus::new(r1cs.clone(), l, num_rows);
         let mut prover_transcript = Transcript::new(b"bench");
-        let (partial_proof, _) = ShockwavePlus.prove(&witness, &mut prover_transcript);
+        let (partial_proof, _) =
+            ShockwavePlus.prove(&witness, &r1cs.public_input, &mut prover_transcript);

         let mut verifier_transcript = Transcript::new(b"bench");
         ShockwavePlus.verify_partial(&partial_proof, &mut verifier_transcript);
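
Note: the verifier's new z_eval line relies on the Z layout introduced in this commit. A rough derivation (my own, assuming Z = (witness padded to 2^{m-1}, 1, io, padding) with big-endian variable ordering, so the first variable y_0 selects between the two halves):

    Z(y_0, y_1, ..., y_{m-1}) = (1 - y_0) * W(y_1, ..., y_{m-1}) + y_0 * Io(y_1, ..., y_{m-1})

where W is the MLE of the padded witness and Io is the MLE of (1, io, 0, ...). Because input_poly is built from construct_z with an all-zero witness, its first half vanishes and its MLE equals the y_0 * Io(...) term, which gives z_eval = (F::ONE - ry[0]) * witness_eval + input_poly_eval and explains why the witness polynomial only needs to be opened at the suffix ry[1..].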

shockwave_plus/src/polynomial/ml_poly.rs (+34, -2)

@@ -26,14 +26,46 @@ impl MlPoly {
     // Evaluate the multilinear extension of the polynomial `a`, at point `t`.
     // `a` is in evaluation form.
+    // `t` should be in big-endian form.
     pub fn eval(&self, t: &[F]) -> F {
         let n = self.evals.len();
         debug_assert_eq!((n as f64).log2() as usize, t.len());
-        // Evaluate the multilinear extension of the polynomial `a`,
-        // over the boolean hypercube
         let eq_evals = EqPoly::new(t.to_vec()).evals();
         Self::dot_prod(&self.evals, &eq_evals)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    type F = halo2curves::secp256k1::Fp;
+    use halo2curves::ff::Field;
+
+    #[test]
+    fn test_ml_poly_eval() {
+        let num_vars = 4;
+        let num_evals = 2usize.pow(num_vars as u32);
+        let evals = (0..num_evals)
+            .map(|x| F::from(x as u64))
+            .collect::<Vec<F>>();
+
+        let ml_poly = MlPoly::new(evals.clone());
+
+        let eval_last = ml_poly.eval(&[F::ONE, F::ONE, F::ONE, F::ONE]);
+        assert_eq!(
+            eval_last,
+            evals[evals.len() - 1],
+            "The last evaluation is not correct"
+        );
+
+        let eval_first = ml_poly.eval(&[F::ZERO, F::ZERO, F::ZERO, F::ZERO]);
+        assert_eq!(eval_first, evals[0], "The first evaluation is not correct");
+
+        let eval_second = ml_poly.eval(&[F::ZERO, F::ZERO, F::ZERO, F::ONE]);
+        assert_eq!(
+            eval_second, evals[1],
+            "The second evaluation is not correct"
+        );
+    }
+}

shockwave_plus/src/r1cs/r1cs.rs (+105, -56)

@@ -20,7 +20,6 @@ where
     F: FieldExt,
 {
     pub fn new(entries: Vec<SparseMatrixEntry<F>>, num_cols: usize, num_rows: usize) -> Self {
-        assert!((num_cols * num_rows).is_power_of_two());
         Self {
            entries,
            num_cols,

@@ -28,8 +27,9 @@ where
        }
    }

-    pub fn mul_vector(&self, num_rows: usize, vec: &[F]) -> Vec<F> {
-        let mut result = vec![F::ZERO; num_rows];
+    pub fn mul_vector(&self, vec: &[F]) -> Vec<F> {
+        debug_assert_eq!(vec.len(), self.num_cols);
+        let mut result = vec![F::ZERO; self.num_rows];
         let entries = &self.entries;
         for i in 0..entries.len() {
             let row = entries[i].row;

@@ -149,12 +149,25 @@ where
         result
     }

-    pub fn produce_synthetic_r1cs(
-        num_cons: usize,
-        num_vars: usize,
-        num_input: usize,
-    ) -> (Self, Vec<F>) {
-        // assert_eq!(num_cons, num_vars);
+    pub fn z_len(&self) -> usize {
+        ((self.num_vars.next_power_of_two() + 1) + self.num_input).next_power_of_two()
+    }
+
+    pub fn construct_z(witness: &[F], public_input: &[F]) -> Vec<F> {
+        // Z = (witness, 1, io)
+        let mut z = witness.to_vec();
+        // Pad the witness part of z to have a power of two length
+        z.resize(z.len().next_power_of_two(), F::ZERO);
+
+        z.push(F::ONE);
+        z.extend(public_input.clone());
+        // Pad the (1, io) part of z to have a power of two length
+        z.resize(z.len().next_power_of_two(), F::ZERO);
+
+        z
+    }
+
+    pub fn produce_synthetic_r1cs(num_vars: usize, num_input: usize) -> (Self, Vec<F>) {
         let mut public_input = Vec::with_capacity(num_input);
         let mut witness = Vec::with_capacity(num_vars);

@@ -166,16 +179,17 @@ where
             witness.push(F::from((i + 1) as u64));
         }

-        let z: Vec<F> = vec![public_input.clone(), witness.clone()].concat();
+        let z = Self::construct_z(&witness, &public_input);

         let mut A_entries: Vec<SparseMatrixEntry<F>> = vec![];
         let mut B_entries: Vec<SparseMatrixEntry<F>> = vec![];
         let mut C_entries: Vec<SparseMatrixEntry<F>> = vec![];

+        let num_cons = z.len();
         for i in 0..num_cons {
-            let A_col = i % num_vars;
-            let B_col = (i + 1) % num_vars;
-            let C_col = (i + 2) % num_vars;
+            let A_col = i % num_cons;
+            let B_col = (i + 1) % num_cons;
+            let C_col = (i + 2) % num_cons;

             // For the i'th constraint,
             // add the value 1 at the (i % num_vars)th column of A, B.

@@ -183,28 +197,35 @@ where
             // we apply multiplication since the Hadamard product is computed for Az ・ Bz,

             // We only _enable_ a single variable in each constraint.
+            let AB = if z[C_col] == F::ZERO { F::ZERO } else { F::ONE };
             A_entries.push(SparseMatrixEntry {
                 row: i,
                 col: A_col,
-                val: F::ONE,
+                val: AB,
             });
             B_entries.push(SparseMatrixEntry {
                 row: i,
                 col: B_col,
-                val: F::ONE,
+                val: AB,
             });
             C_entries.push(SparseMatrixEntry {
                 row: i,
                 col: C_col,
-                val: (z[A_col] * z[B_col]) * z[C_col].invert().unwrap(),
+                val: if z[C_col] == F::ZERO {
+                    F::ZERO
+                } else {
+                    (z[A_col] * z[B_col]) * z[C_col].invert().unwrap()
+                },
             });
         }

-        let A = Matrix::new(A_entries, num_vars, num_cons);
-        let B = Matrix::new(B_entries, num_vars, num_cons);
-        let C = Matrix::new(C_entries, num_vars, num_cons);
+        let num_cols = z.len();
+        let num_rows = num_cols;
+        let A = Matrix::new(A_entries, num_cols, num_rows);
+        let B = Matrix::new(B_entries, num_cols, num_rows);
+        let C = Matrix::new(C_entries, num_cols, num_rows);

         (
             Self {

@@ -220,14 +241,11 @@ where
         )
     }

-    pub fn is_sat(&self, witness: &Vec<F>, public_input: &Vec<F>) -> bool {
-        let mut z = Vec::with_capacity(witness.len() + public_input.len() + 1);
-        z.extend(public_input);
-        z.extend(witness);
-
-        let Az = self.A.mul_vector(self.num_cons, &z);
-        let Bz = self.B.mul_vector(self.num_cons, &z);
-        let Cz = self.C.mul_vector(self.num_cons, &z);
+    pub fn is_sat(&self, witness: &[F], public_input: &[F]) -> bool {
+        let z = Self::construct_z(witness, public_input);
+        let Az = self.A.mul_vector(&z);
+        let Bz = self.B.mul_vector(&z);
+        let Cz = self.C.mul_vector(&z);

         Self::hadamard_prod(&Az, &Bz) == Cz
     }

@@ -235,6 +253,9 @@

 #[cfg(test)]
 mod tests {
+    use halo2curves::ff::Field;
+
     use crate::utils::boolean_hypercube;

     use super::*;

@@ -243,11 +264,11 @@ mod tests {

     #[test]
     fn test_r1cs() {
-        let num_cons = 2usize.pow(5);
-        let num_vars = num_cons;
-        let num_input = 0;
+        let num_cons = 10;
+        let num_input = 3;
+        let num_vars = num_cons - num_input;

-        let (r1cs, mut witness) = R1CS::<F>::produce_synthetic_r1cs(num_cons, num_vars, num_input);
+        let (r1cs, mut witness) = R1CS::<F>::produce_synthetic_r1cs(num_vars, num_input);

         assert_eq!(witness.len(), num_vars);
         assert_eq!(r1cs.public_input.len(), num_input);

@@ -255,51 +276,79 @@ mod tests {
         assert!(r1cs.is_sat(&witness, &r1cs.public_input));

         // Should assert if the witness is invalid
-        witness[0] = witness[0] + F::one();
-        assert!(r1cs.is_sat(&r1cs.public_input, &witness) == false);
-        witness[0] = witness[0] - F::one();
+        witness[0] = witness[0] + F::ONE;
+        assert!(r1cs.is_sat(&witness, &r1cs.public_input) == false);
+        witness[0] = witness[0] - F::ONE;

-        /*
         // Should assert if the public input is invalid
         let mut public_input = r1cs.public_input.clone();
-        public_input[0] = public_input[0] + F::one();
+        public_input[0] = public_input[0] + F::ONE;
         assert!(r1cs.is_sat(&witness, &public_input) == false);
-        */
+        public_input[0] = public_input[0] - F::ONE;

         // Test MLE
-        let s = (num_vars as f64).log2() as usize;
         let A_mle = r1cs.A.to_ml_extension();
         let B_mle = r1cs.B.to_ml_extension();
         let C_mle = r1cs.C.to_ml_extension();
-        let Z_mle = MlPoly::new(witness);
+        let z = R1CS::construct_z(&witness, &public_input);
+        let Z_mle = MlPoly::new(z);
+        let s = Z_mle.num_vars;

         for c in &boolean_hypercube(s) {
-            let mut eval_a = F::zero();
-            let mut eval_b = F::zero();
-            let mut eval_c = F::zero();
+            let mut eval_a = F::ZERO;
+            let mut eval_b = F::ZERO;
+            let mut eval_c = F::ZERO;
             for b in &boolean_hypercube(s) {
-                let mut b_rev = b.clone();
-                b_rev.reverse();
-                let z_eval = Z_mle.eval(&b_rev);
-                let mut eval_matrix = [b.as_slice(), c.as_slice()].concat();
-                eval_matrix.reverse();
+                let z_eval = Z_mle.eval(&b);
+                let eval_matrix = [c.as_slice(), b.as_slice()].concat();
                 eval_a += A_mle.eval(&eval_matrix) * z_eval;
                 eval_b += B_mle.eval(&eval_matrix) * z_eval;
                 eval_c += C_mle.eval(&eval_matrix) * z_eval;
             }

             let eval_con = eval_a * eval_b - eval_c;
-            assert_eq!(eval_con, F::zero());
+            assert_eq!(eval_con, F::ZERO);
         }
     }

-    /*
     #[test]
-    fn test_fast_uni_eval() {
-        let (r1cs, _) = R1CS::<F>::produce_synthetic_r1cs(8, 8, 0);
-
-        let eval_at = F::from(33);
-        let result = r1cs.A.fast_uni_eval(r1cs.num_vars, eval_at);
-        println!("result: {:?}", result);
-    }
-    */
+    fn test_construct_z() {
+        let num_cons = 10;
+        let num_input = 3;
+        let num_vars = num_cons - num_input;
+
+        let (r1cs, witness) = R1CS::<F>::produce_synthetic_r1cs(num_vars, num_input);
+        let Z = R1CS::construct_z(&witness, &r1cs.public_input);
+
+        // The first num_vars should equal to Z
+        let Z_mle = SparseMLPoly::from_dense(Z.clone());
+        for (i, b) in boolean_hypercube(Z_mle.num_vars - 1).iter().enumerate() {
+            assert_eq!(Z[i], Z_mle.eval(&[&[F::ZERO], b.as_slice()].concat()));
+        }
+
+        for (i, b) in boolean_hypercube(Z_mle.num_vars - 1).iter().enumerate() {
+            if i == 0 {
+                assert_eq!(F::ONE, Z_mle.eval(&[&[F::ONE], b.as_slice()].concat()));
+            } else if (i - 1) < r1cs.public_input.len() {
+                assert_eq!(
+                    r1cs.public_input[i - 1],
+                    Z_mle.eval(&[&[F::ONE], b.as_slice()].concat())
+                );
+            } else {
+                assert_eq!(F::ZERO, Z_mle.eval(&[&[F::ONE], b.as_slice()].concat()));
+            }
+        }
+    }
+
+    #[test]
+    fn test_z_len() {
+        let num_cons = 10;
+        let num_input = 3;
+        let num_vars = num_cons - num_input;
+
+        let (r1cs, witness) = R1CS::<F>::produce_synthetic_r1cs(num_vars, num_input);
+        let z = R1CS::construct_z(&witness, &r1cs.public_input);
+        assert_eq!(z.len(), r1cs.z_len());
+    }
 }
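
Note: to make the new Z layout concrete, here is a minimal standalone sketch (hypothetical values, u64 in place of field elements) of what construct_z builds and z_len predicts for num_vars = 5 and three public inputs:

    // Hypothetical illustration of the Z layout: (witness | padding | 1 | io | padding).
    fn main() {
        let witness: Vec<u64> = vec![11, 12, 13, 14, 15]; // num_vars = 5
        let public_input: Vec<u64> = vec![1, 2, 3]; // num_input = 3

        // Witness half, padded to a power of two: [11, 12, 13, 14, 15, 0, 0, 0]
        let mut z = witness.clone();
        z.resize(z.len().next_power_of_two(), 0);

        // (1, io) half, padded to a power of two: [1, 1, 2, 3, 0, 0, 0, 0]
        z.push(1);
        z.extend(&public_input);
        z.resize(z.len().next_power_of_two(), 0);

        // Matches z_len(): ((5.next_power_of_two() + 1) + 3).next_power_of_two() == 16
        assert_eq!(
            z.len(),
            ((witness.len().next_power_of_two() + 1) + public_input.len()).next_power_of_two()
        );
        println!("{:?}", z);
    }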

shockwave_plus/src/sumcheck/sc_phase_1.rs (+6, -6)

@@ -120,10 +120,12 @@ impl SumCheckPhase1 {
         let v_C = C_table[0];

         // Prove the evaluation of the blinder polynomial at rx.
-        let mut rx_rev = self.challenge.clone();
-        rx_rev.reverse();
-        let blinder_poly_eval_proof =
-            pcs.open(&blinder_poly_comm, &blinder_poly, &rx_rev, transcript);
+        let blinder_poly_eval_proof = pcs.open(
+            &blinder_poly_comm,
+            &blinder_poly,
+            &self.challenge,
+            transcript,
+        );

         (
             SCPhase1Proof {

@@ -141,8 +143,6 @@ impl SumCheckPhase1 {
         let zero = F::ZERO;
         let one = F::ONE;

-        println!("v phase 1 rho = {:?}", rho);
-
         // target = 0 + rho * blinder_poly_sum
         let mut target = rho * proof.blinder_poly_sum;
         for (i, round_poly) in proof.round_polys.iter().enumerate() {

shockwave_plus/src/sumcheck/sc_phase_2.rs (+2, -4)

@@ -128,11 +128,9 @@ impl SumCheckPhase2 {
             round_polys.push(round_poly);
         }

-        let mut r_y_rev = self.challenge.clone();
-        r_y_rev.reverse();
+        let ry = self.challenge.clone();

-        let blinder_poly_eval_proof =
-            pcs.open(&blinder_poly_comm, &blinder_poly, &r_y_rev, transcript);
+        let blinder_poly_eval_proof = pcs.open(&blinder_poly_comm, &blinder_poly, &ry, transcript);

         SCPhase2Proof {
             round_polys,

shockwave_plus/src/utils.rs (+2, -1)

@@ -1,6 +1,6 @@
 use crate::FieldExt;

-// Returns a vector of vectors of length m, where each vector is a boolean vector (little endian)
+// Returns a vector of vectors of length m, where each vector is a boolean vector (big endian)
 pub fn boolean_hypercube<F: FieldExt>(m: usize) -> Vec<Vec<F>> {
     let n = 2usize.pow(m as u32);

@@ -12,6 +12,7 @@ pub fn boolean_hypercube<F: FieldExt>(m: usize) -> Vec<Vec<F>> {
             let i_b = F::from((i >> j & 1) as u64);
             tmp.push(i_b);
         }
+        tmp.reverse();
         boolean_hypercube.push(tmp);
     }
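
Note: a small worked example of the big-endian ordering (my own illustration, not from the diff): for m = 2 the function now returns the bits of the index most-significant first,

    i = 0 -> [0, 0]
    i = 1 -> [0, 1]
    i = 2 -> [1, 0]
    i = 3 -> [1, 1]

whereas before the added reverse() the vectors were little-endian.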

tensor_pcs/Cargo.toml (+2, -1)

@@ -17,5 +17,6 @@ halo2curves = "0.1.0"
 criterion = { version = "0.4", features = ["html_reports"] }

 [[bench]]
-name = "prove"
+name = "pcs"
+path = "benches/prove.rs"
 harness = false

tensor_pcs/src/polynomial/eq_poly.rs (+24, -18)

@@ -7,6 +7,7 @@ pub struct EqPoly {
 }

 impl<F: FieldExt> EqPoly<F> {
+    // `t` should be in big-endian.
     pub fn new(t: Vec<F>) -> Self {
         Self { t }
     }

@@ -23,18 +24,18 @@ impl EqPoly {
     // Copied from microsoft/Spartan
     pub fn evals(&self) -> Vec<F> {
-        let ell = self.t.len(); // 4
+        let ell = self.t.len();

         let mut evals: Vec<F> = vec![F::ONE; 2usize.pow(ell as u32)];
         let mut size = 1;
         for j in 0..ell {
             // in each iteration, we double the size of chis
-            size *= 2; // 2 4 8 16
+            size *= 2;
             for i in (0..size).rev().step_by(2) {
                 // copy each element from the prior iteration twice
-                let scalar = evals[i / 2]; // i = 0, 2, 4, 7
-                evals[i] = scalar * self.t[j]; // (1 * t0)(1 * t1)
-                evals[i - 1] = scalar - evals[i]; // 1 - (1 * t0)(1 * t1)
+                let scalar = evals[i / 2];
+                evals[i] = scalar * self.t[j];
+                evals[i - 1] = scalar - evals[i];
             }
         }
         evals

@@ -44,25 +45,30 @@ impl EqPoly {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::polynomial::sparse_ml_poly::SparseMLPoly;
-    use halo2curves::ff::Field;
     type F = halo2curves::secp256k1::Fp;
-
-    pub fn dot_prod<F: FieldExt>(x: &[F], y: &[F]) -> F {
-        assert_eq!(x.len(), y.len());
-        let mut result = F::ZERO;
-        for i in 0..x.len() {
-            result += x[i] * y[i];
-        }
-        result
-    }
+    use halo2curves::ff::Field;

     #[test]
     fn test_eq_poly() {
         let m = 4;
         let t = (0..m).map(|i| F::from((i + 33) as u64)).collect::<Vec<F>>();
         let eq_poly = EqPoly::new(t.clone());
-        eq_poly.evals();
+        let evals = eq_poly.evals();
+
+        let eval_first = eq_poly.eval(&[F::ZERO, F::ZERO, F::ZERO, F::ZERO]);
+        assert_eq!(eval_first, evals[0], "The first evaluation is not correct");
+
+        let eval_second = eq_poly.eval(&[F::ZERO, F::ZERO, F::ZERO, F::ONE]);
+        assert_eq!(
+            eval_second, evals[1],
+            "The second evaluation is not correct"
+        );
+
+        let eval_last = eq_poly.eval(&[F::ONE, F::ONE, F::ONE, F::ONE]);
+        assert_eq!(
+            eval_last,
+            evals[evals.len() - 1],
+            "The last evaluation is not correct"
+        );
     }
 }
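
Note: as a reminder of what the new assertions check, the eq polynomial satisfies the standard identity (not part of the diff)

    eq(t, x) = \prod_{i} (t_i * x_i + (1 - t_i) * (1 - x_i)),

so with `t` in big-endian form, evals()[k] should agree with eval() at the big-endian bit decomposition of k; the test exercises k = 0, 1, and 2^m - 1.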

tensor_pcs/src/polynomial/sparse_ml_poly.rs (+36, -1)

@@ -14,8 +14,8 @@ impl SparseMLPoly {
     pub fn from_dense(dense_evals: Vec<F>) -> Self {
         let sparse_evals = dense_evals
             .iter()
-            .filter(|eval| **eval != F::ZERO)
             .enumerate()
+            .filter(|(_, eval)| **eval != F::ZERO)
             .map(|(i, eval)| (i, *eval))
             .collect::<Vec<(usize, F)>>();
         let num_vars = (dense_evals.len() as f64).log2() as usize;

@@ -26,7 +26,9 @@ impl SparseMLPoly {
         }
     }

+    // `t` should be in big-endian form.
     pub fn eval(&self, t: &[F]) -> F {
+        assert_eq!(self.num_vars, t.len());
         // Evaluate the multilinear extension of the polynomial `a`,
         // over the boolean hypercube

@@ -42,3 +44,36 @@ impl SparseMLPoly {
         result
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    type F = halo2curves::secp256k1::Fp;
+    use halo2curves::ff::Field;
+
+    #[test]
+    fn test_sparse_ml_poly_eval() {
+        let num_vars = 4;
+        let num_evals = 2usize.pow(num_vars as u32);
+        let evals = (0..num_evals)
+            .map(|x| F::from((x as u64) as u64))
+            .collect::<Vec<F>>();
+
+        let ml_poly = SparseMLPoly::from_dense(evals.clone());
+
+        let eval_last = ml_poly.eval(&[F::ONE, F::ONE, F::ONE, F::ONE]);
+        assert_eq!(
+            eval_last,
+            evals[evals.len() - 1],
+            "The last evaluation is not correct"
+        );
+
+        let eval_first = ml_poly.eval(&[F::ZERO, F::ZERO, F::ZERO, F::ZERO]);
+        assert_eq!(eval_first, evals[0], "The first evaluation is not correct");
+
+        let eval_second = ml_poly.eval(&[F::ZERO, F::ZERO, F::ZERO, F::ONE]);
+        assert_eq!(
+            eval_second, evals[1],
+            "The second evaluation is not correct"
+        );
+    }
+}
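
Note: the from_dense change above is an index-preservation fix: enumerating before filtering keeps each non-zero entry paired with its original dense index. A minimal sketch of the difference (my own, with plain integers instead of field elements):

    fn main() {
        let dense = vec![0u64, 5, 0, 7];

        // New order: enumerate first, so surviving entries keep their dense index.
        let fixed: Vec<(usize, u64)> = dense
            .iter()
            .enumerate()
            .filter(|(_, v)| **v != 0)
            .map(|(i, v)| (i, *v))
            .collect();
        assert_eq!(fixed, vec![(1, 5), (3, 7)]);

        // Old order: filtering first renumbers the survivors, losing their positions.
        let old: Vec<(usize, u64)> = dense
            .iter()
            .filter(|v| **v != 0)
            .enumerate()
            .map(|(i, v)| (i, *v))
            .collect();
        assert_eq!(old, vec![(0, 5), (1, 7)]);
    }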

tensor_pcs/src/tensor_pcs.rs (+4, -10)

@@ -120,11 +120,9 @@ impl TensorMultilinearPCS {
         // ########################################

         // Get the evaluation point
-        let mut point_rev = point.to_vec();
-        point_rev.reverse();
-
         let log2_num_rows = (num_rows as f64).log2() as usize;
-        let q1 = EqPoly::new(point_rev[0..log2_num_rows].to_vec()).evals();
+        let q1 = EqPoly::new(point[0..log2_num_rows].to_vec()).evals();

         let eval_r_prime = rlc_rows(blinder, &q1);

@@ -134,7 +132,7 @@ impl TensorMultilinearPCS {
         TensorMLOpening {
             x: point.to_vec(),
-            y: poly.eval(&point_rev),
+            y: poly.eval(&point),
             eval_query_leaves: eval_queries,
             test_query_leaves: test_queries,
             u_hat_comm: u_hat_comm.committed_tree.root(),

@@ -163,7 +161,6 @@ impl TensorMultilinearPCS {
         // ########################################

         let r_u = transcript.challenge_vec(num_rows);
-        println!("r_u = {:?}", r_u);

         let test_u_prime_rs_codeword = self
             .rs_encode(&opening.test_u_prime)

@@ -197,12 +194,9 @@ impl TensorMultilinearPCS {
         // Verify evaluation phase
         // ########################################

-        let mut x_rev = opening.x.clone();
-        x_rev.reverse();
-
         let log2_num_rows = (num_rows as f64).log2() as usize;
-        let q1 = EqPoly::new(x_rev[0..log2_num_rows].to_vec()).evals();
-        let q2 = EqPoly::new(x_rev[log2_num_rows..].to_vec()).evals();
+        let q1 = EqPoly::new(opening.x[0..log2_num_rows].to_vec()).evals();
+        let q2 = EqPoly::new(opening.x[log2_num_rows..].to_vec()).evals();

         let eval_u_prime_rs_codeword = self
             .rs_encode(&opening.eval_u_prime)
