diff --git a/Cargo.toml b/Cargo.toml index daccafc..1d758ec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,6 +47,9 @@ serde = { version = "1.0", features = [ "derive" ], default-features = false, op winter_crypto = { version = "0.7", package = "winter-crypto", default-features = false } winter_math = { version = "0.7", package = "winter-math", default-features = false } winter_utils = { version = "0.7", package = "winter-utils", default-features = false } +rayon = "1.8.0" +rand = "0.8.4" +rand_core = { version = "0.5", default-features = false } [dev-dependencies] criterion = { version = "0.5", features = ["html_reports"] } diff --git a/src/gkr/circuit/mod.rs b/src/gkr/circuit/mod.rs new file mode 100644 index 0000000..d599081 --- /dev/null +++ b/src/gkr/circuit/mod.rs @@ -0,0 +1,977 @@ +use alloc::sync::Arc; +use winter_crypto::{ElementHasher, RandomCoin}; +use winter_math::fields::f64::BaseElement; +use winter_math::FieldElement; + +use crate::gkr::multivariate::{ + ComposedMultiLinearsOracle, EqPolynomial, GkrCompositionVanilla, MultiLinearOracle, +}; +use crate::gkr::sumcheck::{sum_check_verify, Claim}; + +use super::multivariate::{ + gen_plain_gkr_oracle, gkr_composition_from_composition_polys, ComposedMultiLinears, + CompositionPolynomial, MultiLinear, +}; +use super::sumcheck::{ + sum_check_prove, sum_check_verify_and_reduce, FinalEvaluationClaim, + PartialProof as SumcheckInstanceProof, RoundProof as SumCheckRoundProof, Witness, +}; + +/// Layered circuit for computing a sum of fractions. +/// +/// The circuit computes a sum of fractions based on the formula a / c + b / d = (a * d + b * c) / (c * d) +/// which defines a "gate" ((a, b), (c, d)) --> (a * d + b * c, c * d) upon which the `FractionalSumCircuit` +/// is built. +/// TODO: Swap 1 and 0 +#[derive(Debug)] +pub struct FractionalSumCircuit { + p_1_vec: Vec>, + p_0_vec: Vec>, + q_1_vec: Vec>, + q_0_vec: Vec>, +} + +impl FractionalSumCircuit { + /// Computes The values of the gates outputs for each of the layers of the fractional sum circuit. + pub fn new_(num_den: &Vec>) -> Self { + let mut p_1_vec: Vec> = Vec::new(); + let mut p_0_vec: Vec> = Vec::new(); + let mut q_1_vec: Vec> = Vec::new(); + let mut q_0_vec: Vec> = Vec::new(); + + let num_layers = num_den[0].len().ilog2() as usize; + + p_1_vec.push(num_den[0].to_owned()); + p_0_vec.push(num_den[1].to_owned()); + q_1_vec.push(num_den[2].to_owned()); + q_0_vec.push(num_den[3].to_owned()); + + for i in 0..num_layers { + let (output_p_1, output_p_0, output_q_1, output_q_0) = + FractionalSumCircuit::compute_layer( + &p_1_vec[i], + &p_0_vec[i], + &q_1_vec[i], + &q_0_vec[i], + ); + p_1_vec.push(output_p_1); + p_0_vec.push(output_p_0); + q_1_vec.push(output_q_1); + q_0_vec.push(output_q_0); + } + + FractionalSumCircuit { p_1_vec, p_0_vec, q_1_vec, q_0_vec } + } + + /// Compute the output values of the layer given a set of input values + fn compute_layer( + inp_p_1: &MultiLinear, + inp_p_0: &MultiLinear, + inp_q_1: &MultiLinear, + inp_q_0: &MultiLinear, + ) -> (MultiLinear, MultiLinear, MultiLinear, MultiLinear) { + let len = inp_q_1.len(); + let outp_p_1 = (0..len / 2) + .map(|i| inp_p_1[i] * inp_q_0[i] + inp_p_0[i] * inp_q_1[i]) + .collect::>(); + let outp_p_0 = (len / 2..len) + .map(|i| inp_p_1[i] * inp_q_0[i] + inp_p_0[i] * inp_q_1[i]) + .collect::>(); + let outp_q_1 = (0..len / 2).map(|i| inp_q_1[i] * inp_q_0[i]).collect::>(); + let outp_q_0 = (len / 2..len).map(|i| inp_q_1[i] * inp_q_0[i]).collect::>(); + + ( + MultiLinear::new(outp_p_1), + MultiLinear::new(outp_p_0), + MultiLinear::new(outp_q_1), + MultiLinear::new(outp_q_0), + ) + } + + /// Computes The values of the gates outputs for each of the layers of the fractional sum circuit. + pub fn new(poly: &MultiLinear) -> Self { + let mut p_1_vec: Vec> = Vec::new(); + let mut p_0_vec: Vec> = Vec::new(); + let mut q_1_vec: Vec> = Vec::new(); + let mut q_0_vec: Vec> = Vec::new(); + + let num_layers = poly.len().ilog2() as usize - 1; + let (output_p, output_q) = poly.split(poly.len() / 2); + let (output_p_1, output_p_0) = output_p.split(output_p.len() / 2); + let (output_q_1, output_q_0) = output_q.split(output_q.len() / 2); + + p_1_vec.push(output_p_1); + p_0_vec.push(output_p_0); + q_1_vec.push(output_q_1); + q_0_vec.push(output_q_0); + + for i in 0..num_layers - 1 { + let (output_p_1, output_p_0, output_q_1, output_q_0) = + FractionalSumCircuit::compute_layer( + &p_1_vec[i], + &p_0_vec[i], + &q_1_vec[i], + &q_0_vec[i], + ); + p_1_vec.push(output_p_1); + p_0_vec.push(output_p_0); + q_1_vec.push(output_q_1); + q_0_vec.push(output_q_0); + } + + FractionalSumCircuit { p_1_vec, p_0_vec, q_1_vec, q_0_vec } + } + + /// Given a value r, computes the evaluation of the last layer at r when interpreted as (two) + /// multilinear polynomials. + pub fn evaluate(&self, r: E) -> (E, E) { + let len = self.p_1_vec.len(); + assert_eq!(self.p_1_vec[len - 1].num_variables(), 0); + assert_eq!(self.p_0_vec[len - 1].num_variables(), 0); + assert_eq!(self.q_1_vec[len - 1].num_variables(), 0); + assert_eq!(self.q_0_vec[len - 1].num_variables(), 0); + + let mut p_1 = self.p_1_vec[len - 1].clone(); + p_1.extend(&self.p_0_vec[len - 1]); + let mut q_1 = self.q_1_vec[len - 1].clone(); + q_1.extend(&self.q_0_vec[len - 1]); + + (p_1.evaluate(&[r]), q_1.evaluate(&[r])) + } +} + +/// A proof for reducing a claim on the correctness of the output of a layer to that of: +/// +/// 1. Correctness of a sumcheck proof on the claimed output. +/// 2. Correctness of the evaluation of the input (to the said layer) at a random point when +/// interpreted as multilinear polynomial. +/// +/// The verifier will then have to work backward and: +/// +/// 1. Verify that the sumcheck proof is valid. +/// 2. Recurse on the (claimed evaluations) using the same approach as above. +/// +/// Note that the following struct batches proofs for many circuits of the same type that +/// are independent i.e., parallel. +#[derive(Debug)] +pub struct LayerProof { + pub proof: SumcheckInstanceProof, + pub claims_sum_p1: E, + pub claims_sum_p0: E, + pub claims_sum_q1: E, + pub claims_sum_q0: E, +} + +#[allow(dead_code)] +impl + 'static> LayerProof { + /// Checks the validity of a `LayerProof`. + /// + /// It first reduces the 2 claims to 1 claim using randomness and then checks that the sumcheck + /// protocol was correctly executed. + /// + /// The method outputs: + /// + /// 1. A vector containing the randomness sent by the verifier throughout the course of the + /// sum-check protocol. + /// 2. The (claimed) evaluation of the inner polynomial (i.e., the one being summed) at the this random vector. + /// 3. The random value used in the 2-to-1 reduction of the 2 sumchecks. + pub fn verify_sum_check_before_last< + C: RandomCoin, + H: ElementHasher, + >( + &self, + claim: (E, E), + num_rounds: usize, + transcript: &mut C, + ) -> ((E, Vec), E) { + // Absorb the claims + let data = vec![claim.0, claim.1]; + transcript.reseed(H::hash_elements(&data)); + + // Squeeze challenge to reduce two sumchecks to one + let r_sum_check: E = transcript.draw().unwrap(); + + // Run the sumcheck protocol + + // Given r_sum_check and claim, we create a Claim with the GKR composer and then call the generic sum-check verifier + let reduced_claim = claim.0 + claim.1 * r_sum_check; + + // Create vanilla oracle + let oracle = gen_plain_gkr_oracle(num_rounds, r_sum_check); + + // Create sum-check claim + let transformed_claim = Claim { + sum_value: reduced_claim, + polynomial: oracle, + }; + + let reduced_gkr_claim = + sum_check_verify_and_reduce(&transformed_claim, self.proof.clone(), transcript); + + (reduced_gkr_claim, r_sum_check) + } +} + +#[derive(Debug)] +pub struct GkrClaim { + evaluation_point: Vec, + claimed_evaluation: (E, E), +} + +#[derive(Debug)] +pub struct CircuitProof { + pub proof: Vec>, +} + +impl + 'static> CircuitProof { + pub fn prove< + C: RandomCoin, + H: ElementHasher, + >( + circuit: &mut FractionalSumCircuit, + transcript: &mut C, + ) -> (Self, Vec, Vec>) { + let mut proof_layers: Vec> = Vec::new(); + let num_layers = circuit.p_0_vec.len(); + + let data = vec![ + circuit.p_1_vec[num_layers - 1][0], + circuit.p_0_vec[num_layers - 1][0], + circuit.q_1_vec[num_layers - 1][0], + circuit.q_0_vec[num_layers - 1][0], + ]; + transcript.reseed(H::hash_elements(&data)); + + // Challenge to reduce p1, p0, q1, q0 to pr, qr + let r_cord = transcript.draw().unwrap(); + + // Compute the (2-to-1 folded) claim + let mut claim = circuit.evaluate(r_cord); + let mut all_rand = Vec::new(); + + let mut rand = Vec::new(); + rand.push(r_cord); + for layer_id in (0..num_layers - 1).rev() { + let len = circuit.p_0_vec[layer_id].len(); + + // Construct the Lagrange kernel evaluated at previous GKR round randomness. + // TODO: Treat the direction of doing sum-check more robustly. + let mut rand_reversed = rand.clone(); + rand_reversed.reverse(); + let eq_evals = EqPolynomial::new(rand_reversed.clone()).evaluations(); + let mut poly_x = MultiLinear::from_values(&eq_evals); + assert_eq!(poly_x.len(), len); + + let num_rounds = poly_x.len().ilog2() as usize; + + // 1. A is a polynomial containing the evaluations `p_1`. + // 2. B is a polynomial containing the evaluations `p_0`. + // 3. C is a polynomial containing the evaluations `q_1`. + // 4. D is a polynomial containing the evaluations `q_0`. + let poly_a: &mut MultiLinear; + let poly_b: &mut MultiLinear; + let poly_c: &mut MultiLinear; + let poly_d: &mut MultiLinear; + poly_a = &mut circuit.p_1_vec[layer_id]; + poly_b = &mut circuit.p_0_vec[layer_id]; + poly_c = &mut circuit.q_1_vec[layer_id]; + poly_d = &mut circuit.q_0_vec[layer_id]; + + let poly_vec_par = (poly_a, poly_b, poly_c, poly_d, &mut poly_x); + + // The (non-linear) polynomial combining the multilinear polynomials + let comb_func = |a: &E, b: &E, c: &E, d: &E, x: &E, rho: &E| -> E { + (*a * *d + *b * *c + *rho * *c * *d) * *x + }; + + // Run the sumcheck protocol + let (proof, rand_sumcheck, claims_sum) = sum_check_prover_gkr_before_last::( + claim, + num_rounds, + poly_vec_par, + comb_func, + transcript, + ); + + let (claims_sum_p1, claims_sum_p0, claims_sum_q1, claims_sum_q0, _claims_eq) = + claims_sum; + + let data = vec![claims_sum_p1, claims_sum_p0, claims_sum_q1, claims_sum_q0]; + transcript.reseed(H::hash_elements(&data)); + + // Produce a random challenge to condense claims into a single claim + let r_layer = transcript.draw().unwrap(); + + claim = ( + claims_sum_p1 + r_layer * (claims_sum_p0 - claims_sum_p1), + claims_sum_q1 + r_layer * (claims_sum_q0 - claims_sum_q1), + ); + + // Collect the randomness used for the current layer in order to construct the random + // point where the input multilinear polynomials were evaluated. + let mut ext = rand_sumcheck; + ext.push(r_layer); + all_rand.push(rand); + rand = ext; + + proof_layers.push(LayerProof { + proof, + claims_sum_p1, + claims_sum_p0, + claims_sum_q1, + claims_sum_q0, + }); + } + + (CircuitProof { proof: proof_layers }, rand, all_rand) + } + + pub fn prove_virtual_bus< + C: RandomCoin, + H: ElementHasher, + >( + composition_polys: Vec>>>, + mls: &mut Vec>, + transcript: &mut C, + ) -> (Vec, Self, super::sumcheck::FullProof) { + let num_evaluations = 1 << mls[0].num_variables(); + + // I) Evaluate the numerators and denominators over the boolean hyper-cube + let mut num_den: Vec> = vec![vec![]; 4]; + for i in 0..num_evaluations { + for j in 0..4 { + let query: Vec = mls.iter().map(|ml| ml[i]).collect(); + + composition_polys[j].iter().for_each(|c| { + let evaluation = c.as_ref().evaluate(&query); + num_den[j].push(evaluation); + }); + } + } + + // II) Evaluate the GKR fractional sum circuit + let input: Vec> = + (0..4).map(|i| MultiLinear::from_values(&num_den[i])).collect(); + let mut circuit = FractionalSumCircuit::new_(&input); + + // III) Run the GKR prover for all layers except the last one + let (gkr_proofs, GkrClaim { evaluation_point, claimed_evaluation }) = + CircuitProof::prove_before_final(&mut circuit, transcript); + + // IV) Run the sum-check prover for the last GKR layer counting backwards i.e., first layer + // in the circuit. + + // 1) Build the EQ polynomial (Lagrange kernel) at the randomness sampled during the previous + // sum-check protocol run + let mut rand_reversed = evaluation_point.clone(); + rand_reversed.reverse(); + let eq_evals = EqPolynomial::new(rand_reversed.clone()).evaluations(); + let poly_x = MultiLinear::from_values(&eq_evals); + + // 2) Add the Lagrange kernel to the list of MLs + mls.push(poly_x); + + // 3) Absorb the final sum-check claims and generate randomness for 2-to-1 sum-check reduction + let data = vec![claimed_evaluation.0, claimed_evaluation.1]; + transcript.reseed(H::hash_elements(&data)); + // Squeeze challenge to reduce two sumchecks to one + let r_sum_check = transcript.draw().unwrap(); + let reduced_claim = claimed_evaluation.0 + claimed_evaluation.1 * r_sum_check; + + // 4) Create the composed ML representing the numerators and denominators of the topmost GKR layer + let gkr_final_composed_ml = gkr_composition_from_composition_polys( + &composition_polys, + r_sum_check, + 1 << mls[0].num_variables, + ); + let composed_ml = + ComposedMultiLinears::new(Arc::new(gkr_final_composed_ml.clone()), mls.to_vec()); + + // 5) Create the composed ML oracle. This will be used for verifying the FinalEvaluationClaim downstream + // TODO: This should be an input to the current function. + // TODO: Make MultiLinearOracle a variant in an enum so that it is possible to capture other types of oracles. + // For example, shifts of polynomials, Lagrange kernels at a random point or periodic (transparent) polynomials. + let left_num_oracle = MultiLinearOracle { id: 0 }; + let right_num_oracle = MultiLinearOracle { id: 1 }; + let left_denom_oracle = MultiLinearOracle { id: 2 }; + let right_denom_oracle = MultiLinearOracle { id: 3 }; + let eq_oracle = MultiLinearOracle { id: 4 }; + let composed_ml_oracle = ComposedMultiLinearsOracle { + composer: (Arc::new(gkr_final_composed_ml.clone())), + multi_linears: vec![ + eq_oracle, + left_num_oracle, + right_num_oracle, + left_denom_oracle, + right_denom_oracle, + ], + }; + + // 6) Create the claim for the final sum-check protocol. + let claim = Claim { + sum_value: reduced_claim, + polynomial: composed_ml_oracle.clone(), + }; + + // 7) Create the witness for the sum-check claim. + let witness = Witness { polynomial: composed_ml }; + let output = sum_check_prove(&claim, composed_ml_oracle, witness, transcript); + + // 8) Create the claimed output of the circuit. + let circuit_outputs = vec![ + circuit.p_1_vec.last().unwrap()[0], + circuit.p_0_vec.last().unwrap()[0], + circuit.q_1_vec.last().unwrap()[0], + circuit.q_0_vec.last().unwrap()[0], + ]; + + // 9) Return: + // 1. The claimed circuit outputs. + // 2. GKR proofs of all circuit layers except the initial layer. + // 3. Output of the final sum-check protocol. + (circuit_outputs, gkr_proofs, output) + } + + pub fn prove_before_final< + C: RandomCoin, + H: ElementHasher, + >( + sum_circuits: &mut FractionalSumCircuit, + transcript: &mut C, + ) -> (Self, GkrClaim) { + let mut proof_layers: Vec> = Vec::new(); + let num_layers = sum_circuits.p_0_vec.len(); + + let data = vec![ + sum_circuits.p_1_vec[num_layers - 1][0], + sum_circuits.p_0_vec[num_layers - 1][0], + sum_circuits.q_1_vec[num_layers - 1][0], + sum_circuits.q_0_vec[num_layers - 1][0], + ]; + transcript.reseed(H::hash_elements(&data)); + + // Challenge to reduce p1, p0, q1, q0 to pr, qr + let r_cord = transcript.draw().unwrap(); + + // Compute the (2-to-1 folded) claim + let mut claims_to_verify = sum_circuits.evaluate(r_cord); + let mut all_rand = Vec::new(); + + let mut rand = Vec::new(); + rand.push(r_cord); + for layer_id in (1..num_layers - 1).rev() { + let len = sum_circuits.p_0_vec[layer_id].len(); + + // Construct the Lagrange kernel evaluated at previous GKR round randomness. + // TODO: Treat the direction of doing sum-check more robustly. + let mut rand_reversed = rand.clone(); + rand_reversed.reverse(); + let eq_evals = EqPolynomial::new(rand_reversed.clone()).evaluations(); + let mut poly_x = MultiLinear::from_values(&eq_evals); + assert_eq!(poly_x.len(), len); + + let num_rounds = poly_x.len().ilog2() as usize; + + // 1. A is a polynomial containing the evaluations `p_1`. + // 2. B is a polynomial containing the evaluations `p_0`. + // 3. C is a polynomial containing the evaluations `q_1`. + // 4. D is a polynomial containing the evaluations `q_0`. + let poly_a: &mut MultiLinear; + let poly_b: &mut MultiLinear; + let poly_c: &mut MultiLinear; + let poly_d: &mut MultiLinear; + poly_a = &mut sum_circuits.p_1_vec[layer_id]; + poly_b = &mut sum_circuits.p_0_vec[layer_id]; + poly_c = &mut sum_circuits.q_1_vec[layer_id]; + poly_d = &mut sum_circuits.q_0_vec[layer_id]; + + let poly_vec = (poly_a, poly_b, poly_c, poly_d, &mut poly_x); + + let claim = claims_to_verify; + + // The (non-linear) polynomial combining the multilinear polynomials + let comb_func = |a: &E, b: &E, c: &E, d: &E, x: &E, rho: &E| -> E { + (*a * *d + *b * *c + *rho * *c * *d) * *x + }; + + // Run the sumcheck protocol + let (proof, rand_sumcheck, claims_sum) = sum_check_prover_gkr_before_last::( + claim, num_rounds, poly_vec, comb_func, transcript, + ); + + let (claims_sum_p1, claims_sum_p0, claims_sum_q1, claims_sum_q0, _claims_eq) = + claims_sum; + + let data = vec![claims_sum_p1, claims_sum_p0, claims_sum_q1, claims_sum_q0]; + transcript.reseed(H::hash_elements(&data)); + + // Produce a random challenge to condense claims into a single claim + let r_layer = transcript.draw().unwrap(); + + claims_to_verify = ( + claims_sum_p1 + r_layer * (claims_sum_p0 - claims_sum_p1), + claims_sum_q1 + r_layer * (claims_sum_q0 - claims_sum_q1), + ); + + // Collect the randomness used for the current layer in order to construct the random + // point where the input multilinear polynomials were evaluated. + let mut ext = rand_sumcheck; + ext.push(r_layer); + all_rand.push(rand); + rand = ext; + + proof_layers.push(LayerProof { + proof, + claims_sum_p1, + claims_sum_p0, + claims_sum_q1, + claims_sum_q0, + }); + } + let gkr_claim = GkrClaim { + evaluation_point: rand.clone(), + claimed_evaluation: claims_to_verify, + }; + + (CircuitProof { proof: proof_layers }, gkr_claim) + } + + pub fn verify< + C: RandomCoin, + H: ElementHasher, + >( + &self, + claims_sum_vec: &(E, E, E, E), + transcript: &mut C, + ) -> ((E, E), Vec) { + let num_layers = self.proof.len() as usize - 1; + let mut rand: Vec = Vec::new(); + + let data = vec![claims_sum_vec.0, claims_sum_vec.1, claims_sum_vec.2, claims_sum_vec.3]; + transcript.reseed(H::hash_elements(&data)); + + let r_cord = transcript.draw().unwrap(); + + let p_poly_coef = vec![claims_sum_vec.0, claims_sum_vec.1]; + let q_poly_coef = vec![claims_sum_vec.2, claims_sum_vec.3]; + + let p_poly = MultiLinear::new(p_poly_coef); + let q_poly = MultiLinear::new(q_poly_coef); + let p_eval = p_poly.evaluate(&[r_cord]); + let q_eval = q_poly.evaluate(&[r_cord]); + + let mut reduced_claim = (p_eval, q_eval); + + rand.push(r_cord); + for (num_rounds, i) in (0..num_layers).enumerate() { + let ((claim_last, rand_sumcheck), r_two_sumchecks) = self.proof[i] + .verify_sum_check_before_last::<_, _>(reduced_claim, num_rounds + 1, transcript); + + let claims_sum_p1 = &self.proof[i].claims_sum_p1; + let claims_sum_p0 = &self.proof[i].claims_sum_p0; + let claims_sum_q1 = &self.proof[i].claims_sum_q1; + let claims_sum_q0 = &self.proof[i].claims_sum_q0; + + let data = vec![ + claims_sum_p1.clone(), + claims_sum_p0.clone(), + claims_sum_q1.clone(), + claims_sum_q0.clone(), + ]; + transcript.reseed(H::hash_elements(&data)); + + assert_eq!(rand.len(), rand_sumcheck.len()); + + let eq: E = (0..rand.len()) + .map(|i| { + rand[i] * rand_sumcheck[i] + (E::ONE - rand[i]) * (E::ONE - rand_sumcheck[i]) + }) + .fold(E::ONE, |acc, term| acc * term); + + let claim_expected: E = (*claims_sum_p1 * *claims_sum_q0 + + *claims_sum_p0 * *claims_sum_q1 + + r_two_sumchecks * *claims_sum_q1 * *claims_sum_q0) + * eq; + + assert_eq!(claim_expected, claim_last); + + // Produce a random challenge to condense claims into a single claim + let r_layer = transcript.draw().unwrap(); + + reduced_claim = ( + *claims_sum_p1 + r_layer * (*claims_sum_p0 - *claims_sum_p1), + *claims_sum_q1 + r_layer * (*claims_sum_q0 - *claims_sum_q1), + ); + + // Collect the randomness' used for the current layer in order to construct the random + // point where the input multilinear polynomials were evaluated. + let mut ext = rand_sumcheck; + ext.push(r_layer); + rand = ext; + } + (reduced_claim, rand) + } + + pub fn verify_virtual_bus< + C: RandomCoin, + H: ElementHasher, + >( + &self, + composition_polys: Vec>>>, + final_layer_proof: super::sumcheck::FullProof, + claims_sum_vec: &(E, E, E, E), + transcript: &mut C, + ) -> (FinalEvaluationClaim, Vec) { + let num_layers = self.proof.len() as usize; + let mut rand: Vec = Vec::new(); + + // Check that a/b + d/e is equal to 0 + assert_ne!(claims_sum_vec.2, E::ZERO); + assert_ne!(claims_sum_vec.3, E::ZERO); + assert_eq!( + claims_sum_vec.0 * claims_sum_vec.3 + claims_sum_vec.1 * claims_sum_vec.2, + E::ZERO + ); + + let data = vec![claims_sum_vec.0, claims_sum_vec.1, claims_sum_vec.2, claims_sum_vec.3]; + transcript.reseed(H::hash_elements(&data)); + + let r_cord = transcript.draw().unwrap(); + + let p_poly_coef = vec![claims_sum_vec.0, claims_sum_vec.1]; + let q_poly_coef = vec![claims_sum_vec.2, claims_sum_vec.3]; + + let p_poly = MultiLinear::new(p_poly_coef); + let q_poly = MultiLinear::new(q_poly_coef); + let p_eval = p_poly.evaluate(&[r_cord]); + let q_eval = q_poly.evaluate(&[r_cord]); + + let mut reduced_claim = (p_eval, q_eval); + + // I) Verify all GKR layers but for the last one counting backwards. + rand.push(r_cord); + for (num_rounds, i) in (0..num_layers).enumerate() { + let ((claim_last, rand_sumcheck), r_two_sumchecks) = self.proof[i] + .verify_sum_check_before_last::<_, _>(reduced_claim, num_rounds + 1, transcript); + + let claims_sum_p1 = &self.proof[i].claims_sum_p1; + let claims_sum_p0 = &self.proof[i].claims_sum_p0; + let claims_sum_q1 = &self.proof[i].claims_sum_q1; + let claims_sum_q0 = &self.proof[i].claims_sum_q0; + + let data = vec![ + claims_sum_p1.clone(), + claims_sum_p0.clone(), + claims_sum_q1.clone(), + claims_sum_q0.clone(), + ]; + transcript.reseed(H::hash_elements(&data)); + + assert_eq!(rand.len(), rand_sumcheck.len()); + + let eq: E = (0..rand.len()) + .map(|i| { + rand[i] * rand_sumcheck[i] + (E::ONE - rand[i]) * (E::ONE - rand_sumcheck[i]) + }) + .fold(E::ONE, |acc, term| acc * term); + + let claim_expected: E = (*claims_sum_p1 * *claims_sum_q0 + + *claims_sum_p0 * *claims_sum_q1 + + r_two_sumchecks * *claims_sum_q1 * *claims_sum_q0) + * eq; + + assert_eq!(claim_expected, claim_last); + + // Produce a random challenge to condense claims into a single claim + let r_layer = transcript.draw().unwrap(); + + reduced_claim = ( + *claims_sum_p1 + r_layer * (*claims_sum_p0 - *claims_sum_p1), + *claims_sum_q1 + r_layer * (*claims_sum_q0 - *claims_sum_q1), + ); + + let mut ext = rand_sumcheck; + ext.push(r_layer); + rand = ext; + } + + // II) Verify the final GKR layer counting backwards. + + // Absorb the claims + let data = vec![reduced_claim.0, reduced_claim.1]; + transcript.reseed(H::hash_elements(&data)); + + // Squeeze challenge to reduce two sumchecks to one + let r_sum_check = transcript.draw().unwrap(); + let reduced_claim = reduced_claim.0 + reduced_claim.1 * r_sum_check; + + let gkr_final_composed_ml = gkr_composition_from_composition_polys( + &composition_polys, + r_sum_check, + 1 << (num_layers + 1), + ); + + // TODO: refactor + let composed_ml_oracle = { + let left_num_oracle = MultiLinearOracle { id: 0 }; + let right_num_oracle = MultiLinearOracle { id: 1 }; + let left_denom_oracle = MultiLinearOracle { id: 2 }; + let right_denom_oracle = MultiLinearOracle { id: 3 }; + let eq_oracle = MultiLinearOracle { id: 4 }; + ComposedMultiLinearsOracle { + composer: (Arc::new(gkr_final_composed_ml.clone())), + multi_linears: vec![ + eq_oracle, + left_num_oracle, + right_num_oracle, + left_denom_oracle, + right_denom_oracle, + ], + } + }; + + let claim = Claim { + sum_value: reduced_claim, + polynomial: composed_ml_oracle.clone(), + }; + + let final_eval_claim = sum_check_verify(&claim, final_layer_proof, transcript); + + (final_eval_claim, rand) + } +} + +fn sum_check_prover_gkr_before_last< + E: FieldElement, + C: RandomCoin, + H: ElementHasher, +>( + claim: (E, E), + num_rounds: usize, + ml_polys: ( + &mut MultiLinear, + &mut MultiLinear, + &mut MultiLinear, + &mut MultiLinear, + &mut MultiLinear, + ), + comb_func: impl Fn(&E, &E, &E, &E, &E, &E) -> E, + transcript: &mut C, +) -> (SumcheckInstanceProof, Vec, (E, E, E, E, E)) { + // Absorb the claims + let data = vec![claim.0, claim.1]; + transcript.reseed(H::hash_elements(&data)); + + // Squeeze challenge to reduce two sumchecks to one + let r_sum_check = transcript.draw().unwrap(); + + let (poly_a, poly_b, poly_c, poly_d, poly_x) = ml_polys; + + let mut e = claim.0 + claim.1 * r_sum_check; + + let mut r: Vec = Vec::new(); + let mut round_proofs: Vec> = Vec::new(); + + for _j in 0..num_rounds { + let evals: (E, E, E) = { + let mut eval_point_0 = E::ZERO; + let mut eval_point_2 = E::ZERO; + let mut eval_point_3 = E::ZERO; + + let len = poly_a.len() / 2; + for i in 0..len { + // The interpolation formula for a linear function is: + // z * A(x) + (1 - z) * A (y) + // z * A(1) + (1 - z) * A(0) + + // eval at z = 0: A(1) + eval_point_0 += comb_func( + &poly_a[i << 1], + &poly_b[i << 1], + &poly_c[i << 1], + &poly_d[i << 1], + &poly_x[i << 1], + &r_sum_check, + ); + + let poly_a_u = poly_a[(i << 1) + 1]; + let poly_a_v = poly_a[i << 1]; + let poly_b_u = poly_b[(i << 1) + 1]; + let poly_b_v = poly_b[i << 1]; + let poly_c_u = poly_c[(i << 1) + 1]; + let poly_c_v = poly_c[i << 1]; + let poly_d_u = poly_d[(i << 1) + 1]; + let poly_d_v = poly_d[i << 1]; + let poly_x_u = poly_x[(i << 1) + 1]; + let poly_x_v = poly_x[i << 1]; + + // eval at z = 2: 2 * A(1) - A(0) + let poly_a_extrapolated_point = poly_a_u + poly_a_u - poly_a_v; + let poly_b_extrapolated_point = poly_b_u + poly_b_u - poly_b_v; + let poly_c_extrapolated_point = poly_c_u + poly_c_u - poly_c_v; + let poly_d_extrapolated_point = poly_d_u + poly_d_u - poly_d_v; + let poly_x_extrapolated_point = poly_x_u + poly_x_u - poly_x_v; + eval_point_2 += comb_func( + &poly_a_extrapolated_point, + &poly_b_extrapolated_point, + &poly_c_extrapolated_point, + &poly_d_extrapolated_point, + &poly_x_extrapolated_point, + &r_sum_check, + ); + + // eval at z = 3: 3 * A(1) - 2 * A(0) = 2 * A(1) - A(0) + A(1) - A(0) + // hence we can compute the evaluation at z + 1 from that of z for z > 1 + let poly_a_extrapolated_point = poly_a_extrapolated_point + poly_a_u - poly_a_v; + let poly_b_extrapolated_point = poly_b_extrapolated_point + poly_b_u - poly_b_v; + let poly_c_extrapolated_point = poly_c_extrapolated_point + poly_c_u - poly_c_v; + let poly_d_extrapolated_point = poly_d_extrapolated_point + poly_d_u - poly_d_v; + let poly_x_extrapolated_point = poly_x_extrapolated_point + poly_x_u - poly_x_v; + + eval_point_3 += comb_func( + &poly_a_extrapolated_point, + &poly_b_extrapolated_point, + &poly_c_extrapolated_point, + &poly_d_extrapolated_point, + &poly_x_extrapolated_point, + &r_sum_check, + ); + } + + (eval_point_0, eval_point_2, eval_point_3) + }; + + let eval_0 = evals.0; + let eval_2 = evals.1; + let eval_3 = evals.2; + + let evals = vec![e - eval_0, eval_2, eval_3]; + let compressed_poly = SumCheckRoundProof { poly_evals: evals }; + + // append the prover's message to the transcript + transcript.reseed(H::hash_elements(&compressed_poly.poly_evals)); + + // derive the verifier's challenge for the next round + let r_j = transcript.draw().unwrap(); + r.push(r_j); + + poly_a.bind_assign(r_j); + poly_b.bind_assign(r_j); + poly_c.bind_assign(r_j); + poly_d.bind_assign(r_j); + + poly_x.bind_assign(r_j); + + e = compressed_poly.evaluate(e, r_j); + + round_proofs.push(compressed_poly); + } + let claims_sum = (poly_a[0], poly_b[0], poly_c[0], poly_d[0], poly_x[0]); + + (SumcheckInstanceProof { round_proofs }, r, claims_sum) +} + +#[cfg(test)] +mod sum_circuit_tests { + use crate::rand::RpoRandomCoin; + + use super::*; + use rand::Rng; + use rand_utils::rand_value; + use BaseElement as Felt; + + /// The following tests the fractional sum circuit to check that \sum_{i = 0}^{log(m)-1} m / 2^{i} = 2 * (m - 1) + #[test] + fn sum_circuit_example() { + let n = 4; // n := log(m) + let mut inp: Vec = (0..n).map(|_| Felt::from(1_u64 << n)).collect(); + let inp_: Vec = (0..n).map(|i| Felt::from(1_u64 << i)).collect(); + inp.extend(inp_.iter()); + + let summation = MultiLinear::new(inp); + + let expected_output = Felt::from(2 * ((1_u64 << n) - 1)); + + let mut circuit = FractionalSumCircuit::new(&summation); + + let seed = [BaseElement::ZERO; 4]; + let mut transcript = RpoRandomCoin::new(seed.into()); + + let (proof, _evals, _) = CircuitProof::prove(&mut circuit, &mut transcript); + + let (p1, q1) = circuit.evaluate(Felt::from(1_u8)); + let (p0, q0) = circuit.evaluate(Felt::from(0_u8)); + assert_eq!(expected_output, (p1 * q0 + q1 * p0) / (q1 * q0)); + + let seed = [BaseElement::ZERO; 4]; + let mut transcript = RpoRandomCoin::new(seed.into()); + let claims = (p0, p1, q0, q1); + proof.verify(&claims, &mut transcript); + } + + // Test the fractional sum GKR in the context of LogUp. + #[test] + fn log_up() { + use rand::distributions::Slice; + + let n: usize = 16; + let num_w: usize = 31; // This should be of the form 2^k - 1 + let rng = rand::thread_rng(); + + let t_table: Vec = (0..(1 << n)).collect(); + let mut m_table: Vec = (0..(1 << n)).map(|_| 0).collect(); + + let t_table_slice = Slice::new(&t_table).unwrap(); + + // Construct the witness columns. Uses sampling with replacement in order to have multiplicities + // different from 1. + let mut w_tables = Vec::new(); + for _ in 0..num_w { + let wi_table: Vec = + rng.clone().sample_iter(&t_table_slice).cloned().take(1 << n).collect(); + + // Construct the multiplicities + wi_table.iter().for_each(|w| { + m_table[*w as usize] += 1; + }); + w_tables.push(wi_table) + } + + // The numerators + let mut p: Vec = m_table.iter().map(|m| Felt::from(*m as u32)).collect(); + p.extend((0..(num_w * (1 << n))).map(|_| Felt::from(1_u32)).collect::>()); + + // Sample the challenge alpha to construct the denominators. + let alpha = rand_value(); + + // Construct the denominators + let mut q: Vec = t_table.iter().map(|t| Felt::from(*t) - alpha).collect(); + for w_table in w_tables { + q.extend(w_table.iter().map(|w| alpha - Felt::from(*w)).collect::>()); + } + + // Build the input to the fractional sum GKR circuit + p.extend(q); + let input = p; + + let summation = MultiLinear::new(input); + + let expected_output = Felt::from(0_u8); + + let mut circuit = FractionalSumCircuit::new(&summation); + + let seed = [BaseElement::ZERO; 4]; + let mut transcript = RpoRandomCoin::new(seed.into()); + + let (proof, _evals, _) = CircuitProof::prove(&mut circuit, &mut transcript); + + let (p1, q1) = circuit.evaluate(Felt::from(1_u8)); + let (p0, q0) = circuit.evaluate(Felt::from(0_u8)); + assert_eq!(expected_output, (p1 * q0 + q1 * p0) / (q1 * q0)); // This check should be part of verification + + let seed = [BaseElement::ZERO; 4]; + let mut transcript = RpoRandomCoin::new(seed.into()); + let claims = (p0, p1, q0, q1); + proof.verify(&claims, &mut transcript); + } +} diff --git a/src/gkr/mod.rs b/src/gkr/mod.rs new file mode 100644 index 0000000..fd717e6 --- /dev/null +++ b/src/gkr/mod.rs @@ -0,0 +1,7 @@ +#![allow(unused_imports)] +#![allow(dead_code)] + +mod sumcheck; +mod multivariate; +mod utils; +mod circuit; \ No newline at end of file diff --git a/src/gkr/multivariate/eq_poly.rs b/src/gkr/multivariate/eq_poly.rs new file mode 100644 index 0000000..85d7b8e --- /dev/null +++ b/src/gkr/multivariate/eq_poly.rs @@ -0,0 +1,34 @@ +use super::FieldElement; + +pub struct EqPolynomial { + r: Vec, +} + +impl EqPolynomial { + pub fn new(r: Vec) -> Self { + EqPolynomial { r } + } + + pub fn evaluate(&self, rho: &[E]) -> E { + assert_eq!(self.r.len(), rho.len()); + (0..rho.len()) + .map(|i| self.r[i] * rho[i] + (E::ONE - self.r[i]) * (E::ONE - rho[i])) + .fold(E::ONE, |acc, term| acc * term) + } + + pub fn evaluations(&self) -> Vec { + let nu = self.r.len(); + + let mut evals: Vec = vec![E::ONE; 1 << nu]; + let mut size = 1; + for j in 0..nu { + size *= 2; + for i in (0..size).rev().step_by(2) { + let scalar = evals[i / 2]; + evals[i] = scalar * self.r[j]; + evals[i - 1] = scalar - evals[i]; + } + } + evals + } +} diff --git a/src/gkr/multivariate/mod.rs b/src/gkr/multivariate/mod.rs new file mode 100644 index 0000000..83f7c66 --- /dev/null +++ b/src/gkr/multivariate/mod.rs @@ -0,0 +1,543 @@ +use core::ops::Index; + +use alloc::sync::Arc; +use winter_math::{fields::f64::BaseElement, log2, FieldElement, StarkField}; + +mod eq_poly; +pub use eq_poly::EqPolynomial; + +#[derive(Clone, Debug)] +pub struct MultiLinear { + pub num_variables: usize, + pub evaluations: Vec, +} + +impl MultiLinear { + pub fn new(values: Vec) -> Self { + Self { + num_variables: log2(values.len()) as usize, + evaluations: values, + } + } + + pub fn from_values(values: &[E]) -> Self { + Self { + num_variables: log2(values.len()) as usize, + evaluations: values.to_owned(), + } + } + + pub fn num_variables(&self) -> usize { + self.num_variables + } + + pub fn evaluations(&self) -> &[E] { + &self.evaluations + } + + pub fn len(&self) -> usize { + self.evaluations.len() + } + + pub fn evaluate(&self, query: &[E]) -> E { + let tensored_query = tensorize(query); + inner_product(&self.evaluations, &tensored_query) + } + + pub fn bind(&self, round_challenge: E) -> Self { + let mut result = vec![E::ZERO; 1 << (self.num_variables() - 1)]; + for i in 0..(1 << (self.num_variables() - 1)) { + result[i] = self.evaluations[i << 1] + + round_challenge * (self.evaluations[(i << 1) + 1] - self.evaluations[i << 1]); + } + Self::from_values(&result) + } + + pub fn bind_assign(&mut self, round_challenge: E) { + let mut result = vec![E::ZERO; 1 << (self.num_variables() - 1)]; + for i in 0..(1 << (self.num_variables() - 1)) { + result[i] = self.evaluations[i << 1] + + round_challenge * (self.evaluations[(i << 1) + 1] - self.evaluations[i << 1]); + } + *self = Self::from_values(&result); + } + + pub fn split(&self, at: usize) -> (Self, Self) { + assert!(at < self.len()); + ( + Self::new(self.evaluations[..at].to_vec()), + Self::new(self.evaluations[at..2 * at].to_vec()), + ) + } + + pub fn extend(&mut self, other: &MultiLinear) { + let other_vec = other.evaluations.to_vec(); + assert_eq!(other_vec.len(), self.len()); + self.evaluations.extend(other_vec); + self.num_variables += 1; + } +} + +impl Index for MultiLinear { + type Output = E; + + fn index(&self, index: usize) -> &E { + &(self.evaluations[index]) + } +} + +/// A multi-variate polynomial for composing individual multi-linear polynomials +pub trait CompositionPolynomial: Sync + Send { + /// The number of variables when interpreted as a multi-variate polynomial. + fn num_variables(&self) -> usize; + + /// Maximum degree in all variables. + fn max_degree(&self) -> usize; + + /// Given a query, of length equal the number of variables, evaluate [Self] at this query. + fn evaluate(&self, query: &[E]) -> E; +} + +pub struct ComposedMultiLinears { + pub composer: Arc>, + pub multi_linears: Vec>, +} + +impl ComposedMultiLinears { + pub fn new( + composer: Arc>, + multi_linears: Vec>, + ) -> Self { + Self { composer, multi_linears } + } + + pub fn num_ml(&self) -> usize { + self.multi_linears.len() + } + + pub fn num_variables(&self) -> usize { + self.composer.num_variables() + } + + pub fn num_variables_ml(&self) -> usize { + self.multi_linears[0].num_variables + } + + pub fn degree(&self) -> usize { + self.composer.max_degree() + } + + pub fn bind(&self, round_challenge: E) -> ComposedMultiLinears { + let result: Vec> = + self.multi_linears.iter().map(|f| f.bind(round_challenge)).collect(); + + Self { + composer: self.composer.clone(), + multi_linears: result, + } + } +} + +#[derive(Clone)] +pub struct ComposedMultiLinearsOracle { + pub composer: Arc>, + pub multi_linears: Vec, +} + +#[derive(Debug, Clone)] +pub struct MultiLinearOracle { + pub id: usize, +} + +// Composition polynomials + +pub struct IdentityComposition { + num_variables: usize, +} + +impl IdentityComposition { + pub fn new() -> Self { + Self { num_variables: 1 } + } +} + +impl CompositionPolynomial for IdentityComposition +where + E: FieldElement, +{ + fn num_variables(&self) -> usize { + self.num_variables + } + + fn max_degree(&self) -> usize { + self.num_variables + } + + fn evaluate(&self, query: &[E]) -> E { + assert_eq!(query.len(), 1); + query[0] + } +} + +pub struct ProjectionComposition { + coordinate: usize, +} + +impl ProjectionComposition { + pub fn new(coordinate: usize) -> Self { + Self { coordinate } + } +} + +impl CompositionPolynomial for ProjectionComposition +where + E: FieldElement, +{ + fn num_variables(&self) -> usize { + 1 + } + + fn max_degree(&self) -> usize { + 1 + } + + fn evaluate(&self, query: &[E]) -> E { + query[self.coordinate] + } +} + +pub struct LogUpDenominatorTableComposition +where + E: FieldElement, +{ + projection_coordinate: usize, + alpha: E, +} + +impl LogUpDenominatorTableComposition +where + E: FieldElement, +{ + pub fn new(projection_coordinate: usize, alpha: E) -> Self { + Self { projection_coordinate, alpha } + } +} + +impl CompositionPolynomial for LogUpDenominatorTableComposition +where + E: FieldElement, +{ + fn num_variables(&self) -> usize { + 1 + } + + fn max_degree(&self) -> usize { + 1 + } + + fn evaluate(&self, query: &[E]) -> E { + query[self.projection_coordinate] + self.alpha + } +} + +pub struct LogUpDenominatorWitnessComposition +where + E: FieldElement, +{ + projection_coordinate: usize, + alpha: E, +} + +impl LogUpDenominatorWitnessComposition +where + E: FieldElement, +{ + pub fn new(projection_coordinate: usize, alpha: E) -> Self { + Self { projection_coordinate, alpha } + } +} + +impl CompositionPolynomial for LogUpDenominatorWitnessComposition +where + E: FieldElement, +{ + fn num_variables(&self) -> usize { + 1 + } + + fn max_degree(&self) -> usize { + 1 + } + + fn evaluate(&self, query: &[E]) -> E { + -(query[self.projection_coordinate] + self.alpha) + } +} + +pub struct ProductComposition { + num_variables: usize, +} + +impl ProductComposition { + pub fn new(num_variables: usize) -> Self { + Self { num_variables } + } +} + +impl CompositionPolynomial for ProductComposition +where + E: FieldElement, +{ + fn num_variables(&self) -> usize { + self.num_variables + } + + fn max_degree(&self) -> usize { + self.num_variables + } + + fn evaluate(&self, query: &[E]) -> E { + query.iter().fold(E::ONE, |acc, x| acc * *x) + } +} + +pub struct SumComposition { + num_variables: usize, +} + +impl SumComposition { + pub fn new(num_variables: usize) -> Self { + Self { num_variables } + } +} + +impl CompositionPolynomial for SumComposition +where + E: FieldElement, +{ + fn num_variables(&self) -> usize { + self.num_variables + } + + fn max_degree(&self) -> usize { + self.num_variables + } + + fn evaluate(&self, query: &[E]) -> E { + query.iter().fold(E::ZERO, |acc, x| acc + *x) + } +} + +pub struct GkrCompositionVanilla +where + E: FieldElement, +{ + num_variables_ml: usize, + num_variables_merge: usize, + combining_randomness: E, + gkr_randomness: Vec, +} + +impl GkrCompositionVanilla +where + E: FieldElement, +{ + pub fn new( + num_variables_ml: usize, + num_variables_merge: usize, + combining_randomness: E, + gkr_randomness: Vec, + ) -> Self { + Self { + num_variables_ml, + num_variables_merge, + combining_randomness, + gkr_randomness, + } + } +} + +impl CompositionPolynomial for GkrCompositionVanilla +where + E: FieldElement, +{ + fn num_variables(&self) -> usize { + self.num_variables_ml // + TODO + } + + fn max_degree(&self) -> usize { + self.num_variables_ml //TODO + } + + fn evaluate(&self, query: &[E]) -> E { + let eval_left_numerator = query[0]; + let eval_right_numerator = query[1]; + let eval_left_denominator = query[2]; + let eval_right_denominator = query[3]; + let eq_eval = query[4]; + + eq_eval + * ((eval_left_numerator * eval_right_denominator + + eval_right_numerator * eval_left_denominator) + + eval_left_denominator * eval_right_denominator * self.combining_randomness) + } +} + +#[derive(Clone)] +pub struct GkrComposition +where + E: FieldElement, +{ + pub num_variables_ml: usize, + pub combining_randomness: E, + + eq_composer: Arc>, + right_numerator_composer: Vec>>, + left_numerator_composer: Vec>>, + right_denominator_composer: Vec>>, + left_denominator_composer: Vec>>, +} + +impl GkrComposition +where + E: FieldElement, +{ + pub fn new( + num_variables_ml: usize, + combining_randomness: E, + eq_composer: Arc>, + right_numerator_composer: Vec>>, + left_numerator_composer: Vec>>, + right_denominator_composer: Vec>>, + left_denominator_composer: Vec>>, + ) -> Self { + Self { + num_variables_ml, + combining_randomness, + eq_composer, + right_numerator_composer, + left_numerator_composer, + right_denominator_composer, + left_denominator_composer, + } + } +} + +impl CompositionPolynomial for GkrComposition +where + E: FieldElement, +{ + fn num_variables(&self) -> usize { + self.num_variables_ml // + TODO + } + + fn max_degree(&self) -> usize { + 3 // TODO + } + + fn evaluate(&self, query: &[E]) -> E { + let eval_right_numerator = self.right_numerator_composer[0].evaluate(query); + let eval_left_numerator = self.left_numerator_composer[0].evaluate(query); + let eval_right_denominator = self.right_denominator_composer[0].evaluate(query); + let eval_left_denominator = self.left_denominator_composer[0].evaluate(query); + let eq_eval = self.eq_composer.evaluate(query); + + let res = eq_eval + * ((eval_left_numerator * eval_right_denominator + + eval_right_numerator * eval_left_denominator) + + eval_left_denominator * eval_right_denominator * self.combining_randomness); + res + } +} + +/// Generates a composed ML polynomial for the initial GKR layer from a vector of composition +/// polynomials. +/// The composition polynomials are divided into LeftNumerator, RightNumerator, LeftDenominator +/// and RightDenominator. +/// TODO: Generalize this to the case where each numerator/denominator contains more than one +/// composition polynomial i.e., a merged composed ML polynomial. +pub fn gkr_composition_from_composition_polys< + E: FieldElement + 'static, +>( + composition_polys: &Vec>>>, + combining_randomness: E, + num_variables: usize, +) -> GkrComposition { + let eq_composer = Arc::new(ProjectionComposition::new(4)); + let left_numerator = composition_polys[0].to_owned(); + let right_numerator = composition_polys[1].to_owned(); + let left_denominator = composition_polys[2].to_owned(); + let right_denominator = composition_polys[3].to_owned(); + GkrComposition::new( + num_variables, + combining_randomness, + eq_composer, + right_numerator, + left_numerator, + right_denominator, + left_denominator, + ) +} + +/// Generates a plain oracle for the sum-check protocol except the final one. +pub fn gen_plain_gkr_oracle + 'static>( + num_rounds: usize, + r_sum_check: E, +) -> ComposedMultiLinearsOracle { + let gkr_composer = Arc::new(GkrCompositionVanilla::new(num_rounds, 0, r_sum_check, vec![])); + + let ml_oracles = vec![ + MultiLinearOracle { id: 0 }, + MultiLinearOracle { id: 1 }, + MultiLinearOracle { id: 2 }, + MultiLinearOracle { id: 3 }, + MultiLinearOracle { id: 4 }, + ]; + + let oracle = ComposedMultiLinearsOracle { + composer: gkr_composer, + multi_linears: ml_oracles, + }; + oracle +} + +fn to_index>(index: &[E]) -> usize { + let res = index.iter().fold(E::ZERO, |acc, term| acc * E::ONE.double() + (*term)); + let res = res.base_element(0); + res.as_int() as usize +} + +fn inner_product(evaluations: &[E], tensored_query: &[E]) -> E { + assert_eq!(evaluations.len(), tensored_query.len()); + evaluations + .iter() + .zip(tensored_query.iter()) + .fold(E::ZERO, |acc, (x_i, y_i)| acc + *x_i * *y_i) +} + +pub fn tensorize(query: &[E]) -> Vec { + let nu = query.len(); + let n = 1 << nu; + + (0..n).map(|i| lagrange_basis_eval(query, i)).collect() +} + +fn lagrange_basis_eval(query: &[E], i: usize) -> E { + query + .iter() + .enumerate() + .map(|(j, x_j)| if i & (1 << j) == 0 { E::ONE - *x_j } else { *x_j }) + .fold(E::ONE, |acc, v| acc * v) +} + +pub fn compute_claim(poly: &ComposedMultiLinears) -> E { + let cube_size = 1 << poly.num_variables_ml(); + let mut res = E::ZERO; + + for i in 0..cube_size { + let eval_point: Vec = + poly.multi_linears.iter().map(|poly| poly.evaluations[i]).collect(); + res += poly.composer.evaluate(&eval_point); + } + res +} diff --git a/src/gkr/sumcheck/mod.rs b/src/gkr/sumcheck/mod.rs new file mode 100644 index 0000000..80b378b --- /dev/null +++ b/src/gkr/sumcheck/mod.rs @@ -0,0 +1,108 @@ +use super::{ + multivariate::{ComposedMultiLinears, ComposedMultiLinearsOracle}, + utils::{barycentric_weights, evaluate_barycentric}, +}; +use winter_math::FieldElement; + +mod prover; +pub use prover::sum_check_prove; +mod verifier; +pub use verifier::{sum_check_verify, sum_check_verify_and_reduce}; +mod tests; + +#[derive(Debug, Clone)] +pub struct RoundProof { + pub poly_evals: Vec, +} + +impl RoundProof { + pub fn to_evals(&self, claim: E) -> Vec { + let mut result = vec![]; + + // s(0) + s(1) = claim + let c0 = claim - self.poly_evals[0]; + + result.push(c0); + result.extend_from_slice(&self.poly_evals); + result + } + + // TODO: refactor once we move to coefficient form + pub(crate) fn evaluate(&self, claim: E, r: E) -> E { + let poly_evals = self.to_evals(claim); + + let points: Vec = (0..poly_evals.len()).map(|i| E::from(i as u8)).collect(); + let evalss: Vec<(E, E)> = + points.iter().zip(poly_evals.iter()).map(|(x, y)| (*x, *y)).collect(); + let weights = barycentric_weights(&evalss); + let new_claim = evaluate_barycentric(&evalss, r, &weights); + new_claim + } +} + +#[derive(Debug, Clone)] +pub struct PartialProof { + pub round_proofs: Vec>, +} + +#[derive(Clone)] +pub struct FinalEvaluationClaim { + pub evaluation_point: Vec, + pub claimed_evaluation: E, + pub polynomial: ComposedMultiLinearsOracle, +} + +#[derive(Clone)] +pub struct FullProof { + pub sum_check_proof: PartialProof, + pub final_evaluation_claim: FinalEvaluationClaim, +} + +pub struct Claim { + pub sum_value: E, + pub polynomial: ComposedMultiLinearsOracle, +} + +#[derive(Debug)] +pub struct RoundClaim { + pub partial_eval_point: Vec, + pub current_claim: E, +} + +pub struct RoundOutput { + proof: PartialProof, + witness: Witness, +} + +impl From> for RoundClaim { + fn from(value: Claim) -> Self { + Self { + partial_eval_point: vec![], + current_claim: value.sum_value, + } + } +} + +pub struct Witness { + pub(crate) polynomial: ComposedMultiLinears, +} + +pub fn reduce_claim( + current_poly: RoundProof, + current_round_claim: RoundClaim, + round_challenge: E, +) -> RoundClaim { + let poly_evals = current_poly.to_evals(current_round_claim.current_claim); + let points: Vec = (0..poly_evals.len()).map(|i| E::from(i as u8)).collect(); + let evalss: Vec<(E, E)> = points.iter().zip(poly_evals.iter()).map(|(x, y)| (*x, *y)).collect(); + let weights = barycentric_weights(&evalss); + let new_claim = evaluate_barycentric(&evalss, round_challenge, &weights); + + let mut new_partial_eval_point = current_round_claim.partial_eval_point; + new_partial_eval_point.push(round_challenge); + + RoundClaim { + partial_eval_point: new_partial_eval_point, + current_claim: new_claim, + } +} diff --git a/src/gkr/sumcheck/prover.rs b/src/gkr/sumcheck/prover.rs new file mode 100644 index 0000000..a25be0e --- /dev/null +++ b/src/gkr/sumcheck/prover.rs @@ -0,0 +1,109 @@ +use super::{Claim, FullProof, RoundProof, Witness}; +use crate::gkr::{ + multivariate::{ComposedMultiLinears, ComposedMultiLinearsOracle}, + sumcheck::{reduce_claim, FinalEvaluationClaim, PartialProof, RoundClaim, RoundOutput}, +}; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; +use winter_crypto::{ElementHasher, RandomCoin}; +use winter_math::{fields::f64::BaseElement, FieldElement}; + +pub fn sum_check_prove< + E: FieldElement, + C: RandomCoin, + H: ElementHasher, +>( + claim: &Claim, + oracle: ComposedMultiLinearsOracle, + witness: Witness, + coin: &mut C, +) -> FullProof { + // Setup first round + let mut prev_claim = RoundClaim { + partial_eval_point: vec![], + current_claim: claim.sum_value.clone(), + }; + let prev_proof = PartialProof { round_proofs: vec![] }; + let num_vars = witness.polynomial.num_variables_ml(); + let prev_output = RoundOutput { proof: prev_proof, witness }; + + let mut output = sumcheck_round(prev_output); + let poly_evals = &output.proof.round_proofs[0].poly_evals; + coin.reseed(H::hash_elements(&poly_evals)); + + for i in 1..num_vars { + let round_challenge = coin.draw().unwrap(); + let new_claim = reduce_claim( + output.proof.round_proofs.last().unwrap().clone(), + prev_claim, + round_challenge, + ); + output.witness.polynomial = output.witness.polynomial.bind(round_challenge); + + output = sumcheck_round(output); + prev_claim = new_claim; + + let poly_evals = &output.proof.round_proofs[i].poly_evals; + coin.reseed(H::hash_elements(&poly_evals)); + } + + let round_challenge = coin.draw().unwrap(); + let RoundClaim { partial_eval_point, current_claim } = reduce_claim( + output.proof.round_proofs.last().unwrap().clone(), + prev_claim, + round_challenge, + ); + let final_eval_claim = FinalEvaluationClaim { + evaluation_point: partial_eval_point, + claimed_evaluation: current_claim, + polynomial: oracle, + }; + + FullProof { + sum_check_proof: output.proof, + final_evaluation_claim: final_eval_claim, + } +} + +fn sumcheck_round(prev_proof: RoundOutput) -> RoundOutput { + let RoundOutput { mut proof, witness } = prev_proof; + + let polynomial = witness.polynomial; + let num_ml = polynomial.num_ml(); + let num_vars = polynomial.num_variables_ml(); + let num_rounds = num_vars - 1; + + let mut evals_zero = vec![E::ZERO; num_ml]; + let mut evals_one = vec![E::ZERO; num_ml]; + let mut deltas = vec![E::ZERO; num_ml]; + let mut evals_x = vec![E::ZERO; num_ml]; + + let total_evals = (0..1 << num_rounds).into_iter().map(|i| { + for (j, ml) in polynomial.multi_linears.iter().enumerate() { + evals_zero[j] = ml.evaluations[(i << 1) as usize]; + evals_one[j] = ml.evaluations[(i << 1) + 1]; + } + let mut total_evals = vec![E::ZERO; polynomial.degree()]; + total_evals[0] = polynomial.composer.evaluate(&evals_one); + evals_zero + .iter() + .zip(evals_one.iter().zip(deltas.iter_mut().zip(evals_x.iter_mut()))) + .for_each(|(a0, (a1, (delta, evx)))| { + *delta = *a1 - *a0; + *evx = *a1; + }); + total_evals.iter_mut().skip(1).for_each(|e| { + evals_x.iter_mut().zip(deltas.iter()).for_each(|(evx, delta)| { + *evx += *delta; + }); + *e = polynomial.composer.evaluate(&evals_x); + }); + total_evals + }); + let evaluations = total_evals.fold(vec![E::ZERO; polynomial.degree()], |mut acc, evals| { + acc.iter_mut().zip(evals.iter()).for_each(|(a, ev)| *a += *ev); + acc + }); + let proof_update = RoundProof { poly_evals: evaluations }; + proof.round_proofs.push(proof_update); + RoundOutput { proof, witness: Witness { polynomial } } +} diff --git a/src/gkr/sumcheck/tests.rs b/src/gkr/sumcheck/tests.rs new file mode 100644 index 0000000..2ed161d --- /dev/null +++ b/src/gkr/sumcheck/tests.rs @@ -0,0 +1,201 @@ +use alloc::sync::Arc; +use rand::{distributions::Uniform, SeedableRng}; +use winter_crypto::RandomCoin; +use winter_math::{fields::f64::BaseElement, FieldElement}; + +use crate::{ + gkr::{ + circuit::{CircuitProof, FractionalSumCircuit}, + multivariate::{ + compute_claim, gkr_composition_from_composition_polys, ComposedMultiLinears, + ComposedMultiLinearsOracle, CompositionPolynomial, EqPolynomial, GkrComposition, + GkrCompositionVanilla, LogUpDenominatorTableComposition, + LogUpDenominatorWitnessComposition, MultiLinear, MultiLinearOracle, + ProjectionComposition, SumComposition, + }, + sumcheck::{ + prover::sum_check_prove, verifier::sum_check_verify, Claim, FinalEvaluationClaim, + FullProof, Witness, + }, + }, + hash::rpo::Rpo256, + rand::RpoRandomCoin, +}; + +#[test] +fn gkr_workflow() { + // generate the data witness for the LogUp argument + let mut mls = generate_logup_witness::(3); + + // the is sampled after receiving the main trace commitment + let alpha = rand_utils::rand_value(); + + // the composition polynomials defining the numerators/denominators + let composition_polys: Vec>>> = vec![ + // left num + vec![Arc::new(ProjectionComposition::new(0))], + // right num + vec![Arc::new(ProjectionComposition::new(1))], + // left den + vec![Arc::new(LogUpDenominatorTableComposition::new(2, alpha))], + // right den + vec![Arc::new(LogUpDenominatorWitnessComposition::new(3, alpha))], + ]; + + // run the GKR prover to obtain: + // 1. The fractional sum circuit output. + // 2. GKR proofs up to the last circuit layer counting backwards. + // 3. GKR proof (i.e., a sum-check proof) for the last circuit layer counting backwards. + let seed = [BaseElement::ZERO; 4]; + let mut transcript = RpoRandomCoin::new(seed.into()); + let (circuit_outputs, gkr_before_last_proof, final_layer_proof) = + CircuitProof::prove_virtual_bus(composition_polys.clone(), &mut mls, &mut transcript); + + let seed = [BaseElement::ZERO; 4]; + let mut transcript = RpoRandomCoin::new(seed.into()); + + // run the GKR verifier to obtain: + // 1. A final evaluation claim. + // 2. Randomness defining the Lagrange kernel in the final sum-check protocol. Note that this + // Lagrange kernel is different from the one used by the STARK (outer) prover to open the MLs + // at the evaluation point. + let circuit_outputs = + (circuit_outputs[0], circuit_outputs[1], circuit_outputs[2], circuit_outputs[3]); + let (final_eval_claim, gkr_lagrange_kernel_rand) = gkr_before_last_proof.verify_virtual_bus( + composition_polys.clone(), + final_layer_proof, + &circuit_outputs, + &mut transcript, + ); + + // the final verification step is composed of: + // 1. Querying the oracles for the openings at the evaluation point. This will be done by the + // (outer) STARK prover using: + // a. The Lagrange kernel (auxiliary) column at the evaluation point. + // b. An extra (auxiliary) column to compute an inner product between two vectors. The first + // being the Lagrange kernel and the second being (\sum_{j=0}^3 mls[j][i] * \lambda_i)_{i\in\{0,..,n\}} + // 2. Evaluating the composition polynomial at the previous openings and checking equality with + // the claimed evaluation. + + // 1. Querying the oracles + + let FinalEvaluationClaim { + evaluation_point, + claimed_evaluation, + polynomial, + } = final_eval_claim; + + // The evaluation of the EQ polynomial can be done by the verifier directly + let eq = (0..gkr_lagrange_kernel_rand.len()) + .map(|i| { + gkr_lagrange_kernel_rand[i] * evaluation_point[i] + + (BaseElement::ONE - gkr_lagrange_kernel_rand[i]) + * (BaseElement::ONE - evaluation_point[i]) + }) + .fold(BaseElement::ONE, |acc, term| acc * term); + + // These are the queries to the oracles. + // They should be provided by the prover non-deterministically + let left_num_eval = mls[0].evaluate(&evaluation_point); + let right_num_eval = mls[1].evaluate(&evaluation_point); + let left_den_eval = mls[2].evaluate(&evaluation_point); + let right_den_eval = mls[3].evaluate(&evaluation_point); + + // The verifier absorbs the claimed openings and generates batching randomness + let mut query = vec![left_num_eval, right_num_eval, left_den_eval, right_den_eval]; + transcript.reseed(Rpo256::hash_elements(&query)); + let lambdas: Vec = vec![ + transcript.draw().unwrap(), + transcript.draw().unwrap(), + transcript.draw().unwrap(), + ]; + let batched_query = + query[0] + query[1] * lambdas[0] + query[2] * lambdas[1] + query[3] * lambdas[2]; + + // The prover generates the Lagrange kernel + let mut rev_evaluation_point = evaluation_point; + rev_evaluation_point.reverse(); + let lagrange_kernel = EqPolynomial::new(rev_evaluation_point).evaluations(); + let tmp_col: Vec = (0..mls[0].len()) + .map(|i| { + mls[0][i] + mls[1][i] * lambdas[0] + mls[2][i] * lambdas[1] + mls[3][i] * lambdas[2] + }) + .collect(); + + // The prover generates the additional auxiliary column for the inner product + let mut running_sum_col = vec![BaseElement::ZERO; tmp_col.len() + 1]; + running_sum_col[0] = BaseElement::ZERO; + for i in 1..(tmp_col.len() + 1) { + running_sum_col[i] = running_sum_col[i - 1] + tmp_col[i - 1] * lagrange_kernel[i - 1]; + } + + // Boundary constraint to check correctness of openings + assert_eq!(batched_query, *running_sum_col.last().unwrap()); + + // 2) Final evaluation and check + query.push(eq); + let verifier_computed = polynomial.composer.evaluate(&query); + + assert_eq!(verifier_computed, claimed_evaluation); +} + +pub fn generate_logup_witness(trace_len: usize) -> Vec> { + let num_variables_ml = trace_len; + let num_evaluations = 1 << num_variables_ml; + let num_witnesses = 1; + let (p, q) = generate_logup_data::(num_variables_ml, num_witnesses); + let numerators: Vec> = p.chunks(num_evaluations).map(|x| x.into()).collect(); + let denominators: Vec> = q.chunks(num_evaluations).map(|x| x.into()).collect(); + + let mut mls = vec![]; + for i in 0..2 { + let ml = MultiLinear::from_values(&numerators[i]); + mls.push(ml); + } + for i in 0..2 { + let ml = MultiLinear::from_values(&denominators[i]); + mls.push(ml); + } + mls +} + +pub fn generate_logup_data( + trace_len: usize, + num_witnesses: usize, +) -> (Vec, Vec) { + use rand::distributions::Slice; + use rand::Rng; + let n: usize = trace_len; + let num_w: usize = num_witnesses; // This should be of the form 2^k - 1 + let rng = rand::rngs::StdRng::seed_from_u64(0); + + let t_table: Vec = (0..(1 << n)).collect(); + let mut m_table: Vec = (0..(1 << n)).map(|_| 0).collect(); + + let t_table_slice = Slice::new(&t_table).unwrap(); + + // Construct the witness columns. Uses sampling with replacement in order to have multiplicities + // different from 1. + let mut w_tables = Vec::new(); + for _ in 0..num_w { + let wi_table: Vec = + rng.clone().sample_iter(&t_table_slice).cloned().take(1 << n).collect(); + + // Construct the multiplicities + wi_table.iter().for_each(|w| { + m_table[*w as usize] += 1; + }); + w_tables.push(wi_table) + } + + // The numerators + let mut p: Vec = m_table.iter().map(|m| E::from(*m as u32)).collect(); + p.extend((0..(num_w * (1 << n))).map(|_| E::from(1_u32)).collect::>()); + + // Construct the denominators + let mut q: Vec = t_table.iter().map(|t| E::from(*t)).collect(); + for w_table in w_tables { + q.extend(w_table.iter().map(|w| E::from(*w)).collect::>()); + } + (p, q) +} diff --git a/src/gkr/sumcheck/verifier.rs b/src/gkr/sumcheck/verifier.rs new file mode 100644 index 0000000..5f9665c --- /dev/null +++ b/src/gkr/sumcheck/verifier.rs @@ -0,0 +1,71 @@ +use winter_crypto::{ElementHasher, RandomCoin}; +use winter_math::{fields::f64::BaseElement, FieldElement}; + +use crate::gkr::utils::{barycentric_weights, evaluate_barycentric}; + +use super::{Claim, FinalEvaluationClaim, FullProof, PartialProof}; + +pub fn sum_check_verify_and_reduce< + E: FieldElement, + C: RandomCoin, + H: ElementHasher, +>( + claim: &Claim, + proofs: PartialProof, + coin: &mut C, +) -> (E, Vec) { + let degree = 3; + let points: Vec = (0..degree + 1).map(|x| E::from(x as u8)).collect(); + let mut sum_value = claim.sum_value.clone(); + let mut randomness = vec![]; + + for proof in proofs.round_proofs { + let partial_evals = proof.poly_evals.clone(); + coin.reseed(H::hash_elements(&partial_evals)); + + // get r + let r: E = coin.draw().unwrap(); + randomness.push(r); + let evals = proof.to_evals(sum_value); + + let point_evals: Vec<_> = points.iter().zip(evals.iter()).map(|(x, y)| (*x, *y)).collect(); + let weights = barycentric_weights(&point_evals); + sum_value = evaluate_barycentric(&point_evals, r, &weights); + } + (sum_value, randomness) +} + +pub fn sum_check_verify< + E: FieldElement, + C: RandomCoin, + H: ElementHasher, +>( + claim: &Claim, + proofs: FullProof, + coin: &mut C, +) -> FinalEvaluationClaim { + let FullProof { + sum_check_proof: proofs, + final_evaluation_claim, + } = proofs; + let Claim { mut sum_value, polynomial } = claim; + let degree = polynomial.composer.max_degree(); + let points: Vec = (0..degree + 1).map(|x| E::from(x as u8)).collect(); + + for proof in proofs.round_proofs { + let partial_evals = proof.poly_evals.clone(); + coin.reseed(H::hash_elements(&partial_evals)); + + // get r + let r: E = coin.draw().unwrap(); + let evals = proof.to_evals(sum_value); + + let point_evals: Vec<_> = points.iter().zip(evals.iter()).map(|(x, y)| (*x, *y)).collect(); + let weights = barycentric_weights(&point_evals); + sum_value = evaluate_barycentric(&point_evals, r, &weights); + } + + assert_eq!(final_evaluation_claim.claimed_evaluation, sum_value); + + final_evaluation_claim +} diff --git a/src/gkr/utils/mod.rs b/src/gkr/utils/mod.rs new file mode 100644 index 0000000..7d2d445 --- /dev/null +++ b/src/gkr/utils/mod.rs @@ -0,0 +1,33 @@ +use winter_math::{FieldElement, batch_inversion}; + + +pub fn barycentric_weights(points: &[(E, E)]) -> Vec { + let n = points.len(); + let tmp = (0..n) + .map(|i| (0..n).filter(|&j| j != i).fold(E::ONE, |acc, j| acc * (points[i].0 - points[j].0))) + .collect::>(); + batch_inversion(&tmp) +} + +pub fn evaluate_barycentric( + points: &[(E, E)], + x: E, + barycentric_weights: &[E], +) -> E { + for &(x_i, y_i) in points { + if x_i == x { + return y_i; + } + } + + let l_x: E = points.iter().fold(E::ONE, |acc, &(x_i, _y_i)| acc * (x - x_i)); + + let sum = (0..points.len()).fold(E::ZERO, |acc, i| { + let x_i = points[i].0; + let y_i = points[i].1; + let w_i = barycentric_weights[i]; + acc + (w_i / (x - x_i) * y_i) + }); + + l_x * sum +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 26fb343..50624ad 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,7 @@ #![cfg_attr(not(feature = "std"), no_std)] -#[cfg(not(feature = "std"))] -#[cfg_attr(test, macro_use)] +//#[cfg(not(feature = "std"))] +//#[cfg_attr(test, macro_use)] extern crate alloc; pub mod dsa; @@ -9,6 +9,7 @@ pub mod hash; pub mod merkle; pub mod rand; pub mod utils; +pub mod gkr; // RE-EXPORTS // ================================================================================================