feat: use hardcoded good curves

This commit is contained in:
Daniel Tehrani
2023-07-30 15:05:45 -07:00
parent a465129225
commit 3546f03844
13 changed files with 632 additions and 101 deletions

View File

@@ -2,7 +2,8 @@
use criterion::{criterion_group, criterion_main, Criterion};
use shockwave_plus::ShockwavePlus;
use shockwave_plus::R1CS;
use tensor_pcs::Transcript;
use tensor_pcs::rs_config::good_curves::secp256k1::secp256k1_good_curve;
use tensor_pcs::{det_num_cols, Transcript};
fn shockwave_plus_bench(c: &mut Criterion) {
type F = halo2curves::secp256k1::Fp;
@@ -15,14 +16,23 @@ fn shockwave_plus_bench(c: &mut Criterion) {
let mut group = c.benchmark_group(format!("ShockwavePlus num_cons: {}", r1cs.num_cons));
let l = 319;
let num_rows = (((2f64 / l as f64).sqrt() * (num_vars as f64).sqrt()) as usize)
.next_power_of_two()
/ 2;
let ShockwavePlus = ShockwavePlus::new(r1cs.clone(), l, num_rows);
let num_cols = det_num_cols(r1cs.z_len(), l);
let (good_curve, coset_offset) =
secp256k1_good_curve((num_cols as f64).log2() as usize + 1);
group.bench_function("config", |b| {
b.iter(|| {
ShockwavePlus::new(r1cs.clone(), l, good_curve, coset_offset);
})
});
let shockwave_plus = ShockwavePlus::new(r1cs.clone(), l, good_curve, coset_offset);
group.bench_function("prove", |b| {
b.iter(|| {
let mut transcript = Transcript::new(b"bench");
ShockwavePlus.prove(&witness, &r1cs.public_input, &mut transcript);
shockwave_plus.prove(&witness, &r1cs.public_input, &mut transcript);
})
});
}

View File

@@ -6,10 +6,10 @@ mod sumcheck;
use ark_std::{end_timer, start_timer};
use serde::{Deserialize, Serialize};
use sumcheck::{SCPhase1Proof, SCPhase2Proof, SumCheckPhase1, SumCheckPhase2};
use tensor_pcs::{ecfft::GoodCurve, *};
// Exports
pub use r1cs::R1CS;
pub use tensor_pcs::*;
#[derive(Serialize, Deserialize)]
pub struct PartialSpartanProof<F: FieldExt> {
@@ -30,21 +30,15 @@ pub struct FullSpartanProof<F: FieldExt> {
}
pub struct ShockwavePlus<F: FieldExt> {
pub r1cs: R1CS<F>,
pub pcs_witness: TensorMultilinearPCS<F>,
pub pcs_blinder: TensorMultilinearPCS<F>,
r1cs: R1CS<F>,
pcs: TensorMultilinearPCS<F>,
}
impl<F: FieldExt> ShockwavePlus<F> {
pub fn new(r1cs: R1CS<F>, l: usize, num_rows: usize) -> Self {
let num_cols = r1cs.num_vars / num_rows;
// Make sure that there are enough columns to run the l queries
assert!(num_cols > l);
pub fn new(r1cs: R1CS<F>, l: usize, good_curve: GoodCurve<F>, coset_offset: (F, F)) -> Self {
let expansion_factor = 2;
let ecfft_config = rs_config::ecfft::gen_config(num_cols.next_power_of_two());
let ecfft_config = rs_config::ecfft::gen_config_form_curve(good_curve, coset_offset);
let pcs_config = TensorRSMultilinearPCSConfig::<F> {
expansion_factor,
@@ -52,31 +46,21 @@ impl<F: FieldExt> ShockwavePlus<F> {
fft_domain: None,
ecfft_config: Some(ecfft_config),
l,
num_entries: r1cs.num_vars,
num_rows,
};
let pcs_witness = TensorMultilinearPCS::new(pcs_config);
let min_num_entries = r1cs.num_vars.next_power_of_two();
let min_num_cols = pcs_config.num_cols(min_num_entries);
let ecfft_config_blinder =
rs_config::ecfft::gen_config((r1cs.z_len() / num_rows).next_power_of_two());
let pcs_blinder_config = TensorRSMultilinearPCSConfig::<F> {
expansion_factor,
domain_powers: None,
fft_domain: None,
ecfft_config: Some(ecfft_config_blinder),
l,
num_entries: r1cs.z_len(),
num_rows,
};
let max_num_entries = r1cs.z_len().next_power_of_two();
let max_num_cols = pcs_config.num_cols(max_num_entries);
// Make sure that there are enough columns to run the l queries
assert!(min_num_cols > l);
let pcs_blinder = TensorMultilinearPCS::new(pcs_blinder_config);
assert_eq!(good_curve.k, (max_num_cols as f64).log2() as usize + 1);
Self {
r1cs,
pcs_witness,
pcs_blinder,
}
let pcs = TensorMultilinearPCS::new(pcs_config);
Self { r1cs, pcs }
}
pub fn prove(
@@ -91,7 +75,7 @@ impl<F: FieldExt> ShockwavePlus<F> {
// Commit the witness polynomial
let comm_witness_timer = start_timer!(|| "Commit witness");
let committed_witness = self.pcs_witness.commit(&witness_poly);
let committed_witness = self.pcs.commit(&witness_poly);
let witness_comm = committed_witness.committed_tree.root;
end_timer!(comm_witness_timer);
@@ -125,7 +109,8 @@ impl<F: FieldExt> ShockwavePlus<F> {
tau.clone(),
rx.clone(),
);
let (sc_proof_1, (v_A, v_B, v_C)) = sc_phase_1.prove(&self.pcs_blinder, transcript);
let (sc_proof_1, (v_A, v_B, v_C)) = sc_phase_1.prove(&self.pcs, transcript);
end_timer!(sc_phase_1_timer);
transcript.append_fe(&v_A);
@@ -149,14 +134,14 @@ impl<F: FieldExt> ShockwavePlus<F> {
ry.clone(),
);
let sc_proof_2 = sc_phase_2.prove(&self.pcs_blinder, transcript);
let sc_proof_2 = sc_phase_2.prove(&self.pcs, transcript);
end_timer!(sc_phase_2_timer);
let z_open_timer = start_timer!(|| "Open witness poly");
// Prove the evaluation of the polynomial Z(y) at ry
let z_eval_proof =
self.pcs_witness
.open(&committed_witness, &witness_poly, &ry[1..], transcript);
let z_eval_proof = self
.pcs
.open(&committed_witness, &witness_poly, &ry[1..], transcript);
end_timer!(z_open_timer);
// Prove the evaluation of the polynomials A(y), B(y), C(y) at ry
@@ -198,7 +183,7 @@ impl<F: FieldExt> ShockwavePlus<F> {
let ex = SumCheckPhase1::verify_round_polys(&partial_proof.sc_proof_1, &rx, rho);
self.pcs_blinder.verify(
self.pcs.verify(
&partial_proof.sc_proof_1.blinder_poly_eval_proof,
transcript,
);
@@ -234,7 +219,7 @@ impl<F: FieldExt> ShockwavePlus<F> {
let final_poly_eval =
SumCheckPhase2::verify_round_polys(T_2, &partial_proof.sc_proof_2, &ry);
self.pcs_blinder.verify(
self.pcs.verify(
&partial_proof.sc_proof_2.blinder_poly_eval_proof,
transcript,
);
@@ -247,8 +232,7 @@ impl<F: FieldExt> ShockwavePlus<F> {
let B_eval = B_mle.eval(&rx_ry);
let C_eval = C_mle.eval(&rx_ry);
self.pcs_witness
.verify(&partial_proof.z_eval_proof, transcript);
self.pcs.verify(&partial_proof.z_eval_proof, transcript);
let witness_len = self.r1cs.num_vars.next_power_of_two();
let input = (0..self.r1cs.num_input)
@@ -270,20 +254,25 @@ impl<F: FieldExt> ShockwavePlus<F> {
#[cfg(test)]
mod tests {
use tensor_pcs::rs_config::good_curves::secp256k1::secp256k1_good_curve;
use super::*;
#[test]
fn test_shockwave_plus() {
type F = halo2curves::secp256k1::Fp;
let num_vars = 2usize.pow(7);
let num_vars = 2usize.pow(6);
let num_input = 3;
let l = 10;
let l = 2;
let (r1cs, witness) = R1CS::<F>::produce_synthetic_r1cs(num_vars, num_input);
let num_rows = 4;
let ShockwavePlus = ShockwavePlus::new(r1cs.clone(), l, num_rows);
let num_cols = det_num_cols(r1cs.z_len(), l);
let k = (num_cols as f64).log2() as usize;
let (good_curve, coset_offset) = secp256k1_good_curve(k + 1);
let ShockwavePlus = ShockwavePlus::new(r1cs.clone(), l, good_curve, coset_offset);
let mut prover_transcript = Transcript::new(b"bench");
let (partial_proof, _) =
ShockwavePlus.prove(&witness, &r1cs.public_input, &mut prover_transcript);