
feat: use hardcoded good curves

main
Daniel Tehrani 2 years ago
parent commit 3546f03844
13 changed files with 631 additions and 100 deletions
  1. shockwave_plus/benches/prove.rs (+16 -6)
  2. shockwave_plus/src/lib.rs (+34 -45)
  3. tensor_pcs/Cargo.toml (+1 -1)
  4. tensor_pcs/benches/prove.rs (+10 -12)
  5. tensor_pcs/src/lib.rs (+4 -2)
  6. tensor_pcs/src/polynomial/sparse_ml_poly.rs (+4 -0)
  7. tensor_pcs/src/rs_config/ecfft.rs (+19 -3)
  8. tensor_pcs/src/rs_config/good_curves/mod.rs (+1 -0)
  9. tensor_pcs/src/rs_config/good_curves/secp256k1.rs (+456 -0)
  10. tensor_pcs/src/rs_config/mod.rs (+1 -0)
  11. tensor_pcs/src/tensor_rs_pcs.rs (+70 -30)
  12. tensor_pcs/src/tree.rs (+0 -1)
  13. tensor_pcs/src/utils.rs (+15 -0)

shockwave_plus/benches/prove.rs (+16 -6)

@ -2,7 +2,8 @@
use criterion::{criterion_group, criterion_main, Criterion};
use shockwave_plus::ShockwavePlus;
use shockwave_plus::R1CS;
use tensor_pcs::Transcript;
use tensor_pcs::rs_config::good_curves::secp256k1::secp256k1_good_curve;
use tensor_pcs::{det_num_cols, Transcript};
fn shockwave_plus_bench(c: &mut Criterion) {
type F = halo2curves::secp256k1::Fp;
@ -15,14 +16,23 @@ fn shockwave_plus_bench(c: &mut Criterion) {
let mut group = c.benchmark_group(format!("ShockwavePlus num_cons: {}", r1cs.num_cons));
let l = 319;
let num_rows = (((2f64 / l as f64).sqrt() * (num_vars as f64).sqrt()) as usize)
.next_power_of_two()
/ 2;
let ShockwavePlus = ShockwavePlus::new(r1cs.clone(), l, num_rows);
let num_cols = det_num_cols(r1cs.z_len(), l);
let (good_curve, coset_offset) =
secp256k1_good_curve((num_cols as f64).log2() as usize + 1);
group.bench_function("config", |b| {
b.iter(|| {
ShockwavePlus::new(r1cs.clone(), l, good_curve, coset_offset);
})
});
let shockwave_plus = ShockwavePlus::new(r1cs.clone(), l, good_curve, coset_offset);
group.bench_function("prove", |b| {
b.iter(|| {
let mut transcript = Transcript::new(b"bench");
ShockwavePlus.prove(&witness, &r1cs.public_input, &mut transcript);
shockwave_plus.prove(&witness, &r1cs.public_input, &mut transcript);
})
});
}
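The `k` passed to `secp256k1_good_curve` in this benchmark follows from the matrix sizing helper added to tensor_pcs/src/utils.rs at the bottom of this diff. A standalone sketch of that arithmetic, assuming a hypothetical z_len and the fixed expansion factor of 2; `det_num_cols` is copied from the helper below:

```rust
// Sizing sketch behind `secp256k1_good_curve((num_cols as f64).log2() as usize + 1)`.
// Mirrors tensor_pcs/src/utils.rs; the z_len value here is hypothetical.
fn det_num_cols(num_entries: usize, l: usize) -> usize {
    let num_entries_sqrt = (num_entries as f64).sqrt() as usize;
    // The number of columns must be a power of two to tensor-query the evaluations.
    std::cmp::max(num_entries_sqrt, l).next_power_of_two()
}

fn main() {
    let z_len = 2usize.pow(16); // hypothetical length of the full assignment z
    let l = 319; // number of queries, as in the benchmark above
    let num_cols = det_num_cols(z_len, l); // max(256, 319).next_power_of_two() = 512
    // With expansion factor 2 the codeword is num_cols * 2 = 1024 entries long,
    // so the curve must supply a domain of size 2^k with k = log2(num_cols) + 1.
    let k = (num_cols as f64).log2() as usize + 1;
    assert_eq!(k, 10);
}
```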

shockwave_plus/src/lib.rs (+34 -45)

@ -6,10 +6,10 @@ mod sumcheck;
use ark_std::{end_timer, start_timer};
use serde::{Deserialize, Serialize};
use sumcheck::{SCPhase1Proof, SCPhase2Proof, SumCheckPhase1, SumCheckPhase2};
use tensor_pcs::{ecfft::GoodCurve, *};
// Exports
pub use r1cs::R1CS;
pub use tensor_pcs::*;
#[derive(Serialize, Deserialize)]
pub struct PartialSpartanProof<F: FieldExt> {
@ -30,21 +30,15 @@ pub struct FullSpartanProof {
}
pub struct ShockwavePlus<F: FieldExt> {
pub r1cs: R1CS<F>,
pub pcs_witness: TensorMultilinearPCS<F>,
pub pcs_blinder: TensorMultilinearPCS<F>,
r1cs: R1CS<F>,
pcs: TensorMultilinearPCS<F>,
}
impl<F: FieldExt> ShockwavePlus<F> {
pub fn new(r1cs: R1CS<F>, l: usize, num_rows: usize) -> Self {
let num_cols = r1cs.num_vars / num_rows;
// Make sure that there are enough columns to run the l queries
assert!(num_cols > l);
pub fn new(r1cs: R1CS<F>, l: usize, good_curve: GoodCurve<F>, coset_offset: (F, F)) -> Self {
let expansion_factor = 2;
let ecfft_config = rs_config::ecfft::gen_config(num_cols.next_power_of_two());
let ecfft_config = rs_config::ecfft::gen_config_form_curve(good_curve, coset_offset);
let pcs_config = TensorRSMultilinearPCSConfig::<F> {
expansion_factor,
@ -52,31 +46,21 @@ impl ShockwavePlus {
fft_domain: None,
ecfft_config: Some(ecfft_config),
l,
num_entries: r1cs.num_vars,
num_rows,
};
let pcs_witness = TensorMultilinearPCS::new(pcs_config);
let min_num_entries = r1cs.num_vars.next_power_of_two();
let min_num_cols = pcs_config.num_cols(min_num_entries);
let ecfft_config_blinder =
rs_config::ecfft::gen_config((r1cs.z_len() / num_rows).next_power_of_two());
let pcs_blinder_config = TensorRSMultilinearPCSConfig::<F> {
expansion_factor,
domain_powers: None,
fft_domain: None,
ecfft_config: Some(ecfft_config_blinder),
l,
num_entries: r1cs.z_len(),
num_rows,
};
let max_num_entries = r1cs.z_len().next_power_of_two();
let max_num_cols = pcs_config.num_cols(max_num_entries);
// Make sure that there are enough columns to run the l queries
assert!(min_num_cols > l);
assert_eq!(good_curve.k, (max_num_cols as f64).log2() as usize + 1);
let pcs_blinder = TensorMultilinearPCS::new(pcs_blinder_config);
let pcs = TensorMultilinearPCS::new(pcs_config);
Self {
r1cs,
pcs_witness,
pcs_blinder,
}
Self { r1cs, pcs }
}
pub fn prove(
@ -91,7 +75,7 @@ impl ShockwavePlus {
// Commit the witness polynomial
let comm_witness_timer = start_timer!(|| "Commit witness");
let committed_witness = self.pcs_witness.commit(&witness_poly);
let committed_witness = self.pcs.commit(&witness_poly);
let witness_comm = committed_witness.committed_tree.root;
end_timer!(comm_witness_timer);
@ -125,7 +109,8 @@ impl ShockwavePlus {
tau.clone(),
rx.clone(),
);
let (sc_proof_1, (v_A, v_B, v_C)) = sc_phase_1.prove(&self.pcs_blinder, transcript);
let (sc_proof_1, (v_A, v_B, v_C)) = sc_phase_1.prove(&self.pcs, transcript);
end_timer!(sc_phase_1_timer);
transcript.append_fe(&v_A);
@ -149,14 +134,14 @@ impl ShockwavePlus {
ry.clone(),
);
let sc_proof_2 = sc_phase_2.prove(&self.pcs_blinder, transcript);
let sc_proof_2 = sc_phase_2.prove(&self.pcs, transcript);
end_timer!(sc_phase_2_timer);
let z_open_timer = start_timer!(|| "Open witness poly");
// Prove the evaluation of the polynomial Z(y) at ry
let z_eval_proof =
self.pcs_witness
.open(&committed_witness, &witness_poly, &ry[1..], transcript);
let z_eval_proof = self
.pcs
.open(&committed_witness, &witness_poly, &ry[1..], transcript);
end_timer!(z_open_timer);
// Prove the evaluation of the polynomials A(y), B(y), C(y) at ry
@ -198,7 +183,7 @@ impl ShockwavePlus {
let ex = SumCheckPhase1::verify_round_polys(&partial_proof.sc_proof_1, &rx, rho);
self.pcs_blinder.verify(
self.pcs.verify(
&partial_proof.sc_proof_1.blinder_poly_eval_proof,
transcript,
);
@ -234,7 +219,7 @@ impl ShockwavePlus {
let final_poly_eval =
SumCheckPhase2::verify_round_polys(T_2, &partial_proof.sc_proof_2, &ry);
self.pcs_blinder.verify(
self.pcs.verify(
&partial_proof.sc_proof_2.blinder_poly_eval_proof,
transcript,
);
@ -247,8 +232,7 @@ impl ShockwavePlus {
let B_eval = B_mle.eval(&rx_ry);
let C_eval = C_mle.eval(&rx_ry);
self.pcs_witness
.verify(&partial_proof.z_eval_proof, transcript);
self.pcs.verify(&partial_proof.z_eval_proof, transcript);
let witness_len = self.r1cs.num_vars.next_power_of_two();
let input = (0..self.r1cs.num_input)
@ -270,20 +254,25 @@ impl ShockwavePlus {
#[cfg(test)]
mod tests {
use tensor_pcs::rs_config::good_curves::secp256k1::secp256k1_good_curve;
use super::*;
#[test]
fn test_shockwave_plus() {
type F = halo2curves::secp256k1::Fp;
let num_vars = 2usize.pow(7);
let num_vars = 2usize.pow(6);
let num_input = 3;
let l = 10;
let l = 2;
let (r1cs, witness) = R1CS::<F>::produce_synthetic_r1cs(num_vars, num_input);
let num_rows = 4;
let ShockwavePlus = ShockwavePlus::new(r1cs.clone(), l, num_rows);
let num_cols = det_num_cols(r1cs.z_len(), l);
let k = (num_cols as f64).log2() as usize;
let (good_curve, coset_offset) = secp256k1_good_curve(k + 1);
let ShockwavePlus = ShockwavePlus::new(r1cs.clone(), l, good_curve, coset_offset);
let mut prover_transcript = Transcript::new(b"bench");
let (partial_proof, _) =
ShockwavePlus.prove(&witness, &r1cs.public_input, &mut prover_transcript);
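Putting the pieces of this file together, a minimal end-to-end sketch of the new constructor flow, assembled from the test above and the benchmark at the top of the diff (the function name and transcript label are illustrative):

```rust
use shockwave_plus::{ShockwavePlus, R1CS};
use tensor_pcs::rs_config::good_curves::secp256k1::secp256k1_good_curve;
use tensor_pcs::{det_num_cols, Transcript};

type F = halo2curves::secp256k1::Fp;

fn prove_small_instance() {
    // Synthetic instance, sized as in the test above.
    let (r1cs, witness) = R1CS::<F>::produce_synthetic_r1cs(2usize.pow(6), 3);
    let l = 2;

    // The curve's k must cover the codeword length:
    // log2(num_cols) + 1 for the fixed expansion factor of 2.
    let num_cols = det_num_cols(r1cs.z_len(), l);
    let k = (num_cols as f64).log2() as usize + 1;
    let (good_curve, coset_offset) = secp256k1_good_curve(k);

    // One PCS instance now serves both the witness and the blinder polynomials,
    // since the matrix shape is derived per polynomial rather than stored in the config.
    let shockwave_plus = ShockwavePlus::new(r1cs.clone(), l, good_curve, coset_offset);

    let mut transcript = Transcript::new(b"example");
    let (_partial_proof, _) =
        shockwave_plus.prove(&witness, &r1cs.public_input, &mut transcript);
}
```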

tensor_pcs/Cargo.toml (+1 -1)

@ -9,7 +9,7 @@ edition = "2021"
rand = "0.8.5"
serde = { version = "1.0.152", features = ["derive"] }
merlin = "3.0.0"
ecfft = { git = "https://github.com/DanTehrani/ecfft" }
ecfft = { git = "https://github.com/DanTehrani/ecfft", branch = "main" }
tiny-keccak = { version = "2.0.2", features = ["keccak"] }
halo2curves = "0.1.0"

tensor_pcs/benches/prove.rs (+10 -12)

@ -15,11 +15,7 @@ fn poly(num_vars: usize) -> SparseMLPoly {
ml_poly
}
fn config_base<F: FieldExt>(ml_poly: &SparseMLPoly<F>) -> TensorRSMultilinearPCSConfig<F> {
let num_vars = ml_poly.num_vars;
let num_evals = 2usize.pow(num_vars as u32);
let num_rows = 2usize.pow((num_vars / 2) as u32);
fn config_base<F: FieldExt>() -> TensorRSMultilinearPCSConfig<F> {
let expansion_factor = 2;
TensorRSMultilinearPCSConfig::<F> {
@ -28,8 +24,6 @@ fn config_base(ml_poly: &SparseMLPoly) -> TensorRSMultilinearPCS
fft_domain: None,
ecfft_config: None,
l: 10,
num_entries: num_evals,
num_rows,
}
}
@ -42,8 +36,10 @@ fn pcs_fft_bench(c: &mut Criterion) {
.map(|i| F::from(i as u64))
.collect::<Vec<F>>();
let mut config = config_base(&ml_poly);
config.fft_domain = Some(rs_config::smooth::gen_config::<F>(config.num_cols()));
let mut config = config_base();
config.fft_domain = Some(rs_config::smooth::gen_config::<F>(
config.num_cols(ml_poly.evals.len()),
));
let mut group = c.benchmark_group("pcs fft");
group.bench_function("prove", |b| {
@ -66,8 +62,10 @@ fn pcs_ecfft_bench(c: &mut Criterion) {
.map(|i| F::from(i as u64))
.collect::<Vec<F>>();
let mut config = config_base(&ml_poly);
config.ecfft_config = Some(rs_config::ecfft::gen_config::<F>(config.num_cols()));
let mut config = config_base();
config.ecfft_config = Some(rs_config::ecfft::gen_config::<F>(
config.num_cols(ml_poly.evals.len()),
));
let mut group = c.benchmark_group("pcs ecfft");
group.bench_function("prove", |b| {
@ -88,6 +86,6 @@ fn set_duration() -> Criterion {
criterion_group! {
name = benches;
config = set_duration();
targets = pcs_fft_bench, pcs_ecfft_bench
targets = pcs_ecfft_bench
}
criterion_main!(benches);
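The benchmark changes reflect the slimmed-down config: `num_entries` and `num_rows` are gone, and the matrix shape is derived from the polynomial length when needed. A sketch of the new pattern, assuming the remaining config fields are exactly those visible in the hunks (`expansion_factor`, `domain_powers`, `fft_domain`, `ecfft_config`, `l`); the `num_evals` value is illustrative:

```rust
use tensor_pcs::{rs_config, TensorRSMultilinearPCSConfig};

type F = halo2curves::secp256k1::Fp;

fn ecfft_config_for(num_evals: usize) -> TensorRSMultilinearPCSConfig<F> {
    let mut config = TensorRSMultilinearPCSConfig::<F> {
        expansion_factor: 2,
        domain_powers: None,
        fft_domain: None,
        ecfft_config: None,
        l: 10,
    };
    // Dimensions come from the polynomial size at call time, not from the config.
    let num_cols = config.num_cols(num_evals);
    config.ecfft_config = Some(rs_config::ecfft::gen_config::<F>(num_cols));
    config
}
```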

tensor_pcs/src/lib.rs (+4 -2)

@ -3,7 +3,7 @@ mod fft;
mod polynomial;
pub mod rs_config;
mod tensor_code;
mod tensor_pcs;
mod tensor_rs_pcs;
mod transcript;
mod tree;
mod utils;
@ -14,7 +14,9 @@ pub trait FieldExt: FromUniformBytes<64, Repr = [u8; 32]> {}
impl FieldExt for halo2curves::secp256k1::Fp {}
impl FieldExt for halo2curves::pasta::Fp {}
pub use ecfft;
pub use polynomial::eq_poly::EqPoly;
pub use polynomial::sparse_ml_poly::SparseMLPoly;
pub use tensor_pcs::{TensorMLOpening, TensorMultilinearPCS, TensorRSMultilinearPCSConfig};
pub use tensor_rs_pcs::{TensorMLOpening, TensorMultilinearPCS, TensorRSMultilinearPCSConfig};
pub use transcript::{AppendToTranscript, Transcript};
pub use utils::{det_num_cols, det_num_rows};

tensor_pcs/src/polynomial/sparse_ml_poly.rs (+4 -0)

@ -11,6 +11,10 @@ impl SparseMLPoly {
Self { evals, num_vars }
}
pub fn num_entries(&self) -> usize {
2usize.pow(self.num_vars as u32)
}
pub fn from_dense(dense_evals: Vec<F>) -> Self {
let sparse_evals = dense_evals
.iter()
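The new `num_entries()` helper is simply 2^num_vars, the dense length that the sparse evaluations get padded to in tensor_rs_pcs.rs further down. A tiny example, assuming a `new(evals, num_vars)` constructor as suggested by the body shown above:

```rust
use tensor_pcs::SparseMLPoly;

type F = halo2curves::secp256k1::Fp;

fn num_entries_example() {
    // A 3-variable multilinear polynomial with two non-zero evaluations.
    let poly = SparseMLPoly::<F>::new(vec![(1, F::from(7u64)), (5, F::from(11u64))], 3);
    // num_entries() = 2^num_vars dense evaluation slots.
    assert_eq!(poly.num_entries(), 8);
}
```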

tensor_pcs/src/rs_config/ecfft.rs (+19 -3)

@ -1,5 +1,5 @@
use crate::FieldExt;
use ecfft::{prepare_domain, prepare_matrices, GoodCurve, Matrix2x2};
use ecfft::{find_coset_offset, prepare_domain, prepare_matrices, GoodCurve, Matrix2x2};
#[derive(Clone, Debug)]
pub struct ECFFTConfig<F: FieldExt> {
@ -8,6 +8,20 @@ pub struct ECFFTConfig {
pub inverse_matrices: Vec<Vec<Matrix2x2<F>>>,
}
pub fn gen_config_form_curve<F: FieldExt>(
good_curve: GoodCurve<F>,
coset_offset: (F, F),
) -> ECFFTConfig<F> {
let domain = prepare_domain(good_curve, coset_offset.0, coset_offset.1);
let (matrices, inverse_matrices) = prepare_matrices(&domain);
ECFFTConfig {
domain,
matrices,
inverse_matrices,
}
}
pub fn gen_config<F: FieldExt>(num_cols: usize) -> ECFFTConfig<F> {
assert!(num_cols.is_power_of_two());
let expansion_factor = 2;
@ -15,8 +29,10 @@ pub fn gen_config(num_cols: usize) -> ECFFTConfig {
let k = (codeword_len as f64).log2() as usize;
let good_curve = GoodCurve::find_k(k);
let domain = prepare_domain(good_curve);
let good_curve = GoodCurve::<F>::find_k(k);
let (coset_offset_x, coset_offset_y) =
find_coset_offset(good_curve.a, good_curve.B_sqrt.square());
let domain = prepare_domain(good_curve, coset_offset_x, coset_offset_y);
let (matrices, inverse_matrices) = prepare_matrices(&domain);
ECFFTConfig {
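`gen_config` still searches for a suitable curve via `GoodCurve::find_k` at runtime, while the new `gen_config_form_curve` skips the search by accepting a precomputed curve and coset offset. A usage sketch (`num_cols` must be a power of two, per the assert in `gen_config`):

```rust
use tensor_pcs::rs_config::ecfft::{gen_config, gen_config_form_curve};
use tensor_pcs::rs_config::good_curves::secp256k1::secp256k1_good_curve;

type F = halo2curves::secp256k1::Fp;

fn build_configs(num_cols: usize) {
    // Old path: search for a good curve of the required size at runtime (slow).
    let _searched = gen_config::<F>(num_cols);

    // New path: reuse a hardcoded secp256k1 curve. Its k must equal
    // log2(num_cols * expansion_factor) = log2(num_cols) + 1 here.
    let k = (num_cols as f64).log2() as usize + 1;
    let (curve, coset_offset) = secp256k1_good_curve(k);
    let _precomputed = gen_config_form_curve(curve, coset_offset);
}
```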

tensor_pcs/src/rs_config/good_curves/mod.rs (+1 -0)

@ -0,0 +1 @@
pub mod secp256k1;

tensor_pcs/src/rs_config/good_curves/secp256k1.rs (+456 -0)

@ -0,0 +1,456 @@
use ecfft::GoodCurve;
type Fp = halo2curves::secp256k1::Fp;
const CURVE_4_A: Fp = Fp::from_raw([
1924362692430828527,
180888387886949819,
14444912836850558493,
2716763698990320170,
]);
const CURVE_4_B_SQRT: Fp = Fp::from_raw([
10596826214460559417,
9041891995856355984,
392200829566232436,
5616829616257048236,
]);
const CURVE_4_GX: Fp = Fp::from_raw([
3060553808241114122,
4367422627483541323,
1326591990371471461,
1051615568340430255,
]);
const CURVE_4_GY: Fp = Fp::from_raw([
1576479964359531032,
10706990284747222844,
2069836301523772900,
11540652371418823164,
]);
const CURVE_4_CX: Fp = Fp::from_raw([
1394469693679244729,
3481743377114570646,
685293755929734561,
9752242693766949385,
]);
const CURVE_4_CY: Fp = Fp::from_raw([
11112828892610998404,
11816693849252775007,
3142482327686601672,
2138128838908646944,
]);
const CURVE_5_A: Fp = Fp::from_raw([
18402892062958705657,
10955586493449255806,
274692491874833279,
3521647190010012104,
]);
const CURVE_5_B_SQRT: Fp = Fp::from_raw([
13277815540701934041,
10000316683343802069,
13748514902267845339,
5043980866827216326,
]);
const CURVE_5_GX: Fp = Fp::from_raw([
8681597433860724212,
16010850653546434744,
1655308633092053542,
13482638234089226570,
]);
const CURVE_5_GY: Fp = Fp::from_raw([
695535688134352662,
12810977243071276429,
6639318313449462421,
9854099205183828948,
]);
const CURVE_5_CX: Fp = Fp::from_raw([
7244846058583153822,
15236867482366246868,
7610066066648153412,
8717324474930230203,
]);
const CURVE_5_CY: Fp = Fp::from_raw([
15955524643521385563,
14108119042026331605,
8376852394104379031,
5145942493709290957,
]);
const CURVE_6_A: Fp = Fp::from_raw([
11754870036548954207,
1758746815041297131,
5040922207106606105,
6156268686419792864,
]);
const CURVE_6_B_SQRT: Fp = Fp::from_raw([
16551703456907310471,
7307795367003411231,
9107551177630293136,
3643865576794489637,
]);
const CURVE_6_GX: Fp = Fp::from_raw([
3057786712414561431,
3924976238282577064,
1535938406046208114,
4471499328874959330,
]);
const CURVE_6_GY: Fp = Fp::from_raw([
7420330572426678478,
11093910355894798679,
8046171174582240023,
16159208434053522767,
]);
const CURVE_6_CX: Fp = Fp::from_raw([
15232142892107882662,
2997925312254061635,
875684261157844424,
8054980201271915862,
]);
const CURVE_6_CY: Fp = Fp::from_raw([
5573271252396838460,
7659129927758801858,
11224891608690076565,
8114225763096549468,
]);
const CURVE_7_A: Fp = Fp::from_raw([
15267998901538414419,
17985868627099147199,
5570198032670981398,
7365202425498739811,
]);
const CURVE_7_B_SQRT: Fp = Fp::from_raw([
13238569970078865336,
1859729155619525190,
2289004025597154627,
16424324367845100069,
]);
const CURVE_7_GX: Fp = Fp::from_raw([
13067803014914932854,
8460655374139991694,
17522879348989963876,
2592776320942502074,
]);
const CURVE_7_GY: Fp = Fp::from_raw([
17581244616257969879,
13563260062750024799,
17836667944921387338,
5158385585024810784,
]);
const CURVE_7_CX: Fp = Fp::from_raw([
17627362889681060942,
10449394617197091758,
11211669951719111062,
18402164978442722259,
]);
const CURVE_7_CY: Fp = Fp::from_raw([
3635904622687808257,
12660024001564793695,
2997578841449112866,
7489869964282615463,
]);
const CURVE_8_A: Fp = Fp::from_raw([
16479441517948017563,
12244661565122532810,
16423402461885171455,
15804938408404708752,
]);
const CURVE_8_B_SQRT: Fp = Fp::from_raw([
4114471407724985276,
6429895848762172356,
9060307719139806083,
1606308100763345976,
]);
const CURVE_8_GX: Fp = Fp::from_raw([
18005553174453754936,
7879246565041753863,
15708703128473390087,
12948592289805182905,
]);
const CURVE_8_GY: Fp = Fp::from_raw([
2637815016833021192,
5625620963822185667,
15498097759340857613,
2802364189360038003,
]);
const CURVE_8_CX: Fp = Fp::from_raw([
12514982531648064548,
7254771947927897203,
6879061275311364813,
4385541459413917142,
]);
const CURVE_8_CY: Fp = Fp::from_raw([
13726278170638118925,
10016993418218833106,
13091102901943378213,
8612533232618193985,
]);
const CURVE_9_A: Fp = Fp::from_raw([
2821731813563793393,
3977895281010832865,
8603743292399951036,
4645234720790204102,
]);
const CURVE_9_B_SQRT: Fp = Fp::from_raw([
15890535715675950137,
7339610358409226035,
12609222160720627891,
12499110658591842997,
]);
const CURVE_9_GX: Fp = Fp::from_raw([
6103380741459351369,
14746101125474414882,
12417547802268400852,
7335532149994146446,
]);
const CURVE_9_GY: Fp = Fp::from_raw([
4181331351064768648,
16489913493464340135,
7051826832725726336,
887431923330984487,
]);
const CURVE_9_CX: Fp = Fp::from_raw([
9670988472649099633,
15261760137634294840,
2288914830631271678,
6241984859397428357,
]);
const CURVE_9_CY: Fp = Fp::from_raw([
3996701096097868069,
16808707541849580191,
2008740307070264540,
10234541905633632584,
]);
const CURVE_10_A: Fp = Fp::from_raw([
13443933661892288238,
6366097774645666914,
12539700524489124232,
2960403700358460234,
]);
const CURVE_10_B_SQRT: Fp = Fp::from_raw([
11179334656770650694,
12204828351656968056,
17469374953427230415,
2698602761568343027,
]);
const CURVE_10_GX: Fp = Fp::from_raw([
4752429915723436981,
6658961595441054005,
943316193080952835,
10509103062531384873,
]);
const CURVE_10_GY: Fp = Fp::from_raw([
7405820527339739030,
1149755149636620515,
12315441721581649311,
9740641083146831387,
]);
const CURVE_10_CX: Fp = Fp::from_raw([
12133915190353440166,
12735241419273571667,
984598181344074714,
4945074058633718103,
]);
const CURVE_10_CY: Fp = Fp::from_raw([
7787944055361603336,
16188630343349344241,
1798488611520969499,
15905180573830923441,
]);
const CURVE_11_A: Fp = Fp::from_raw([
44690967250983077,
13024355091469571869,
2426866618505792061,
5439410159441159777,
]);
const CURVE_11_B_SQRT: Fp = Fp::from_raw([
2482839174035592440,
13977599229562359858,
9165253311652858048,
11796280965050311461,
]);
const CURVE_11_GX: Fp = Fp::from_raw([
3785100838262116535,
14366163517008314631,
6520093107874784461,
1432940500835404998,
]);
const CURVE_11_GY: Fp = Fp::from_raw([
15446934078954168044,
13724149936204307181,
291296515805666972,
17295416299404581082,
]);
const CURVE_11_CX: Fp = Fp::from_raw([
3987179730389290606,
12765099312359453542,
14085665078244679772,
1158541383839945849,
]);
const CURVE_11_CY: Fp = Fp::from_raw([
2812404283588715887,
10748530967036022352,
15279323815639380689,
7472866256744067949,
]);
pub fn secp256k1_good_curve(k: usize) -> (GoodCurve<Fp>, (Fp, Fp)) {
if k == 4 {
(
GoodCurve::new(CURVE_4_A, CURVE_4_B_SQRT, CURVE_4_GX, CURVE_4_GY, k),
(CURVE_4_CX, CURVE_4_CY),
)
} else if k == 5 {
(
GoodCurve::new(CURVE_5_A, CURVE_5_B_SQRT, CURVE_5_GX, CURVE_5_GY, k),
(CURVE_5_CX, CURVE_5_CY),
)
} else if k == 6 {
(
GoodCurve::new(CURVE_6_A, CURVE_6_B_SQRT, CURVE_6_GX, CURVE_6_GY, k),
(CURVE_6_CX, CURVE_6_CY),
)
} else if k == 7 {
(
GoodCurve::new(CURVE_7_A, CURVE_7_B_SQRT, CURVE_7_GX, CURVE_7_GY, k),
(CURVE_7_CX, CURVE_7_CY),
)
} else if k == 8 {
(
GoodCurve::new(CURVE_8_A, CURVE_8_B_SQRT, CURVE_8_GX, CURVE_8_GY, k),
(CURVE_8_CX, CURVE_8_CY),
)
} else if k == 9 {
(
GoodCurve::new(CURVE_9_A, CURVE_9_B_SQRT, CURVE_9_GX, CURVE_9_GY, k),
(CURVE_9_CX, CURVE_9_CY),
)
} else if k == 10 {
(
GoodCurve::new(CURVE_10_A, CURVE_10_B_SQRT, CURVE_10_GX, CURVE_10_GY, k),
(CURVE_10_CX, CURVE_10_CY),
)
} else if k == 11 {
(
GoodCurve::new(CURVE_11_A, CURVE_11_B_SQRT, CURVE_11_GX, CURVE_11_GY, k),
(CURVE_11_CX, CURVE_11_CY),
)
} else {
panic!("k must be between 4 and 11")
}
}
#[cfg(test)]
mod tests {
use ecfft::{find_coset_offset, GoodCurve};
type F = halo2curves::secp256k1::Fp;
fn to_limbs(x: F) -> [u64; 4] {
let bytes = x.to_bytes();
let mut limbs = [0u64; 4];
for i in 0..4 {
let mut limb_i = 0;
for j in 0..8 {
limb_i += (bytes[8 * i + j] as u64) << (8 * j);
}
limbs[i] = limb_i;
}
limbs
}
#[test]
fn find_curves() {
// We expect the tensor-IOP to use a roughly square matrix for now,
// so we only need to find curves sized on the order of the square root
// of the number of evaluations
for k in 4..12 {
let curve = GoodCurve::<F>::find_k(k);
let (coset_offset_x, coset_offset_y) =
find_coset_offset(curve.a, curve.B_sqrt.square());
println!(
"const CURVE_{}_A: Fp = Fp::from_raw(
{:?},
);
const CURVE_{}_B_SQRT: Fp = Fp::from_raw(
{:?},
);
const CURVE_{}_GX: Fp = Fp::from_raw(
{:?},
);
const CURVE_{}_GY: Fp = Fp::from_raw(
{:?}
);
const CURVE_{}_CX: Fp = Fp::from_raw(
{:?},
);
const CURVE_{}_CY: Fp = Fp::from_raw(
{:?},
);
",
k,
to_limbs(curve.a),
k,
to_limbs(curve.B_sqrt),
k,
to_limbs(curve.gx),
k,
to_limbs(curve.gy),
k,
to_limbs(coset_offset_x),
k,
to_limbs(coset_offset_y),
);
}
}
}
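The limb constants in this file come from the `find_curves` test above: it searches for a good curve, derives a coset offset, and prints each field element as four little-endian u64 limbs in `Fp::from_raw` order (run it with something like `cargo test find_curves -- --nocapture`). A condensed sketch of that generation step for a hypothetical k = 12 entry the table does not yet cover:

```rust
use ecfft::{find_coset_offset, GoodCurve};

type F = halo2curves::secp256k1::Fp;

fn print_curve_constants(k: usize) {
    let curve = GoodCurve::<F>::find_k(k);
    // B = B_sqrt * B_sqrt, matching find_coset_offset(curve.a, curve.B_sqrt.square()) in the test.
    let (cx, cy) = find_coset_offset(curve.a, curve.B_sqrt * curve.B_sqrt);

    // Same byte-to-limb conversion as the `to_limbs` helper in the test module.
    let to_limbs = |x: F| -> [u64; 4] {
        let bytes = x.to_bytes();
        core::array::from_fn(|i| u64::from_le_bytes(bytes[8 * i..8 * i + 8].try_into().unwrap()))
    };

    println!("CURVE_{k}_A      = {:?}", to_limbs(curve.a));
    println!("CURVE_{k}_B_SQRT = {:?}", to_limbs(curve.B_sqrt));
    println!("CURVE_{k}_GX     = {:?}", to_limbs(curve.gx));
    println!("CURVE_{k}_GY     = {:?}", to_limbs(curve.gy));
    println!("CURVE_{k}_CX     = {:?}", to_limbs(cx));
    println!("CURVE_{k}_CY     = {:?}", to_limbs(cy));
}
```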

tensor_pcs/src/rs_config/mod.rs (+1 -0)

@ -1,3 +1,4 @@
pub mod ecfft;
pub mod good_curves;
pub mod naive;
pub mod smooth;

tensor_pcs/src/tensor_pcs.rs → tensor_pcs/src/tensor_rs_pcs.rs (+70 -30)

@ -9,7 +9,7 @@ use crate::polynomial::eq_poly::EqPoly;
use crate::polynomial::sparse_ml_poly::SparseMLPoly;
use crate::tensor_code::TensorCode;
use crate::transcript::Transcript;
use crate::utils::{dot_prod, hash_all, rlc_rows, sample_indices};
use crate::utils::{det_num_cols, det_num_rows, dot_prod, hash_all, rlc_rows, sample_indices};
use super::tensor_code::CommittedTensorCode;
@ -20,17 +20,15 @@ pub struct TensorRSMultilinearPCSConfig {
pub fft_domain: Option<Vec<F>>,
pub ecfft_config: Option<ECFFTConfig<F>>,
pub l: usize,
pub num_entries: usize,
pub num_rows: usize,
}
impl<F: FieldExt> TensorRSMultilinearPCSConfig<F> {
pub fn num_cols(&self) -> usize {
self.num_entries / self.num_rows()
pub fn num_cols(&self, num_entries: usize) -> usize {
det_num_cols(num_entries, self.l)
}
pub fn num_rows(&self) -> usize {
self.num_rows
pub fn num_rows(&self, num_entries: usize) -> usize {
det_num_rows(num_entries, self.l)
}
}
@ -51,6 +49,7 @@ pub struct TensorMLOpening {
pub test_r_prime: Vec<F>,
pub eval_r_prime: Vec<F>,
pub eval_u_prime: Vec<F>,
pub poly_num_vars: usize,
}
impl<F: FieldExt> TensorMultilinearPCS<F> {
@ -61,7 +60,10 @@ impl TensorMultilinearPCS {
pub fn commit(&self, poly: &SparseMLPoly<F>) -> CommittedTensorCode<F> {
// Merkle commit to the evaluations of the polynomial
let tensor_code = self.encode_zk(poly);
let tree = tensor_code.commit(self.config.num_cols(), self.config.num_rows());
let tree = tensor_code.commit(
self.config.num_cols(poly.num_entries()),
self.config.num_rows(poly.num_entries()),
);
tree
}
@ -72,10 +74,16 @@ impl TensorMultilinearPCS {
point: &[F],
transcript: &mut Transcript<F>,
) -> TensorMLOpening<F> {
let num_cols = self.config.num_cols();
let num_rows = self.config.num_rows();
let num_cols = self.config.num_cols(poly.num_entries());
let num_rows = self.config.num_rows(poly.num_entries());
debug_assert_eq!(poly.num_vars, point.len());
let mut padded_evals = poly.evals.clone();
padded_evals.resize(
num_cols * num_rows,
(2usize.pow(poly.num_vars as u32), F::ZERO),
);
// ########################################
// Testing phase
// Prove the consistency between the random linear combination of the evaluation tensor (u_prime)
@ -87,7 +95,7 @@ impl TensorMultilinearPCS {
let u = (0..num_rows)
.map(|i| {
poly.evals[(i * num_cols)..((i + 1) * num_cols)]
padded_evals[(i * num_cols)..((i + 1) * num_cols)]
.iter()
.map(|entry| entry.1)
.collect::<Vec<F>>()
@ -143,14 +151,16 @@ impl TensorMultilinearPCS {
base_opening: BaseOpening {
hashes: u_hat_comm.committed_tree.column_roots.clone(),
},
poly_num_vars: poly.num_vars,
}
}
}
impl<F: FieldExt> TensorMultilinearPCS<F> {
pub fn verify(&self, opening: &TensorMLOpening<F>, transcript: &mut Transcript<F>) {
let num_rows = self.config.num_rows();
let num_cols = self.config.num_cols();
let poly_num_entries = 2usize.pow(opening.poly_num_vars as u32);
let num_rows = self.config.num_rows(poly_num_entries);
let num_cols = self.config.num_cols(poly_num_entries);
// Verify the base opening
let base_opening = &opening.base_opening;
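Since the config no longer fixes the matrix shape, `TensorMLOpening` now carries `poly_num_vars`, and `verify` re-derives the shape from it. A minimal sketch of that recomputation (the helper name is illustrative; `det_num_cols`/`det_num_rows` are the utils.rs helpers added in this commit):

```rust
use tensor_pcs::{det_num_cols, det_num_rows};

// What the verifier recomputes from `opening.poly_num_vars` alone.
fn matrix_shape(poly_num_vars: usize, l: usize) -> (usize, usize) {
    let num_entries = 2usize.pow(poly_num_vars as u32);
    (det_num_rows(num_entries, l), det_num_cols(num_entries, l))
}
```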
@ -249,7 +259,13 @@ impl TensorMultilinearPCS {
}
fn rs_encode(&self, message: &[F]) -> Vec<F> {
let mut padded_message = message.to_vec();
padded_message.resize(message.len().next_power_of_two(), F::ZERO);
let codeword_len = padded_message.len() * self.config.expansion_factor;
let codeword_len_log2 = (codeword_len as f64).log2() as usize;
let codeword = if self.config.fft_domain.is_some() {
// TODO: Resize the domain according to the message length
let fft_domain = self.config.fft_domain.as_ref().unwrap();
let mut padded_coeffs = message.clone().to_vec();
padded_coeffs.resize(fft_domain.len(), F::ZERO);
@ -257,13 +273,27 @@ impl TensorMultilinearPCS {
codeword
} else if self.config.ecfft_config.is_some() {
let ecfft_config = self.config.ecfft_config.as_ref().unwrap();
let mut ecfft_config = self.config.ecfft_config.clone().unwrap();
// Resize the domain to the correct size
let config_domain_size = ecfft_config.domain.len();
assert!(config_domain_size >= codeword_len_log2 - 1);
ecfft_config.domain =
ecfft_config.domain[(config_domain_size - (codeword_len_log2 - 1))..].to_vec();
ecfft_config.matrices =
ecfft_config.matrices[(config_domain_size - (codeword_len_log2 - 1))..].to_vec();
ecfft_config.inverse_matrices = ecfft_config.inverse_matrices
[(config_domain_size - (codeword_len_log2 - 1))..]
.to_vec();
assert_eq!(
message.len() * self.config.expansion_factor,
padded_message.len() * self.config.expansion_factor,
ecfft_config.domain[0].len()
);
let extended_evals = extend(
message,
&padded_message,
&ecfft_config.domain,
&ecfft_config.matrices,
&ecfft_config.inverse_matrices,
@ -273,6 +303,7 @@ impl TensorMultilinearPCS {
let codeword = [message.to_vec(), extended_evals].concat();
codeword
} else {
// TODO: Resize the domain according to the message length
let domain_powers = self.config.domain_powers.as_ref().unwrap();
assert_eq!(message.len(), domain_powers[0].len());
assert_eq!(
@ -311,12 +342,20 @@ impl TensorMultilinearPCS {
}
fn encode_zk(&self, poly: &SparseMLPoly<F>) -> TensorCode<F> {
let num_rows = self.config.num_rows();
let num_cols = self.config.num_cols();
let num_rows = self.config.num_rows(poly.num_entries());
let num_cols = self.config.num_cols(poly.num_entries());
// Pad the sparse evaluations with zeros
let mut evals = poly.evals.clone();
evals.resize(
num_cols * num_rows,
(2usize.pow(poly.num_vars as u32), F::ZERO),
);
debug_assert_eq!(evals.len(), num_cols * num_rows);
let codewords = (0..num_rows)
.map(|i| {
poly.evals[i * num_cols..(i + 1) * num_cols]
evals[i * num_cols..(i + 1) * num_cols]
.iter()
.map(|entry| entry.1)
.collect::<Vec<F>>()
@ -330,10 +369,12 @@ impl TensorMultilinearPCS {
#[cfg(test)]
mod tests {
use ::ecfft::find_coset_offset;
use super::*;
use crate::rs_config::{ecfft, naive, smooth};
use crate::rs_config::{ecfft, good_curves::secp256k1::secp256k1_good_curve, naive, smooth};
const TEST_NUM_VARS: usize = 10;
const TEST_NUM_VARS: usize = 8;
const TEST_L: usize = 10;
fn test_poly<F: FieldExt>() -> SparseMLPoly<F> {
@ -364,10 +405,6 @@ mod tests {
}
fn config_base<F: FieldExt>(ml_poly: &SparseMLPoly<F>) -> TensorRSMultilinearPCSConfig<F> {
let num_vars = ml_poly.num_vars;
let num_evals = 2usize.pow(num_vars as u32);
let num_rows = 2usize.pow((num_vars / 2) as u32);
let expansion_factor = 2;
TensorRSMultilinearPCSConfig::<F> {
@ -376,8 +413,6 @@ mod tests {
fft_domain: None,
ecfft_config: None,
l: TEST_L,
num_entries: num_evals,
num_rows,
}
}
@ -387,7 +422,7 @@ mod tests {
// FFT config
let ml_poly = test_poly();
let mut config = config_base(&ml_poly);
config.fft_domain = Some(smooth::gen_config(config.num_cols()));
config.fft_domain = Some(smooth::gen_config(config.num_cols(ml_poly.num_entries())));
// Test FFT PCS
let tensor_pcs_fft = TensorMultilinearPCS::<F>::new(config);
@ -400,7 +435,12 @@ mod tests {
let ml_poly = test_poly();
let mut config = config_base(&ml_poly);
config.ecfft_config = Some(ecfft::gen_config(config.num_cols()));
let num_cols = config.num_cols(ml_poly.num_entries());
let k = ((num_cols * config.expansion_factor).next_power_of_two() as f64).log2() as usize;
let (curve, coset_offset) = secp256k1_good_curve(k);
config.ecfft_config = Some(ecfft::gen_config_form_curve(curve, coset_offset));
// Test FFT PCS
let tensor_pcs_ecf = TensorMultilinearPCS::<F>::new(config);
@ -415,7 +455,7 @@ mod tests {
// Naive config
let mut config = config_base(&ml_poly);
config.domain_powers = Some(naive::gen_config(config.num_cols()));
config.domain_powers = Some(naive::gen_config(config.num_cols(ml_poly.num_entries())));
// Test FFT PCS
let tensor_pcs_naive = TensorMultilinearPCS::<F>::new(config);
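The `rs_encode` changes above let one precomputed ECFFT config, sized for the largest codeword, serve shorter messages: the leading (largest) domain, matrix, and inverse-matrix layers are dropped until the first remaining layer matches the codeword length. A sketch of just that index arithmetic; the sizes are illustrative, and the "9 layers, first of length 1024 for k = 10" figure is inferred from the asserts in the hunk rather than checked against the ecfft crate:

```rust
// Which layer the sliced config starts at, mirroring the arithmetic in rs_encode.
fn first_kept_layer(config_domain_size: usize, message_len: usize, expansion_factor: usize) -> usize {
    let padded_len = message_len.next_power_of_two();
    let codeword_len = padded_len * expansion_factor;
    let codeword_len_log2 = (codeword_len as f64).log2() as usize;
    assert!(config_domain_size >= codeword_len_log2 - 1);
    // Everything before this index is dropped from domain, matrices and inverse_matrices.
    config_domain_size - (codeword_len_log2 - 1)
}

fn main() {
    // A config built from secp256k1_good_curve(10): 9 layers, first of length 1024.
    assert_eq!(first_kept_layer(9, 128, 2), 2); // keep layers 2..9: first has length 256
    assert_eq!(first_kept_layer(9, 512, 2), 0); // full-size message keeps every layer
}
```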

tensor_pcs/src/tree.rs (+0 -1)

@ -12,7 +12,6 @@ pub struct CommittedMerkleTree {
impl<F: FieldExt> CommittedMerkleTree<F> {
pub fn from_leaves(leaves: Vec<F>, num_cols: usize) -> Self {
let n = leaves.len();
debug_assert!(n.is_power_of_two());
let num_rows = n / num_cols;
assert!(num_rows & 1 == 0); // Number of rows must be even

tensor_pcs/src/utils.rs (+15 -0)

@ -75,6 +75,21 @@ pub fn sample_indices(
indices
}
pub fn det_num_cols(num_entries: usize, l: usize) -> usize {
let num_entries_sqrt = (num_entries as f64).sqrt() as usize;
// The number of columns must be a power of two
// to tensor-query the polynomial evaluation
let num_cols = std::cmp::max(num_entries_sqrt, l).next_power_of_two();
num_cols
}
pub fn det_num_rows(num_entries: usize, l: usize) -> usize {
// The number of rows must be a power of two
// to tensor-query the polynomial evaluation
let num_rows = (num_entries / det_num_cols(num_entries, l)).next_power_of_two();
num_rows
}
#[cfg(test)]
mod tests {
use super::*;
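A small worked example of the new sizing helpers, written as a test that could sit in this module; the sizes are illustrative:

```rust
#[test]
fn det_dims_example() {
    // 2^13 entries, l = 10 queries:
    // cols = max(sqrt(8192), 10).next_power_of_two() = max(90, 10) -> 128
    // rows = (8192 / 128).next_power_of_two() = 64
    assert_eq!(det_num_cols(8192, 10), 128);
    assert_eq!(det_num_rows(8192, 10), 64);

    // For small polynomials the query count l dominates the column count.
    assert_eq!(det_num_cols(64, 10), 16); // max(8, 10).next_power_of_two()
    assert_eq!(det_num_rows(64, 10), 4);
}
```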
