commit 55a08ad575ac43d71629b4a784cb696ff26496fc Author: arnaucube Date: Tue Oct 15 16:51:49 2024 +0200 add full recursion (binary) tree Each node of the tree is verifying: `((verify left sig OR verify left proof) AND (verify right sig OR verify right proof))`. and then it generates a new plonky2 proof, which can again be verified in a node of the next level of the tree. full binary tree of recursion works diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..96ef6c0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..1065975 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "plonky2-recursion-experiment" +version = "0.1.0" +edition = "2021" + +[dependencies] +plonky2 = { git = "https://github.com/mir-protocol/plonky2" } +sch = { git = "https://github.com/tideofwords/schnorr" } +anyhow = "1.0.56" +rand = "0.8.5" +hashbrown = { version = "0.14.3", default-features = false, features = ["ahash", "serde"] } +log = { version = "0.4.14", default-features = false } +env_logger = "0.10.0" diff --git a/README.md b/README.md new file mode 100644 index 0000000..a98a0e3 --- /dev/null +++ b/README.md @@ -0,0 +1,9 @@ +# plonky2-recursion-experiment +WARNING: Experimental. Not safe, do not use in production. + +## Run + +```bash +rustup override set nightly # Requires nightly Rust +cargo test --release -- --nocapture +``` diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..ded4537 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,14 @@ +#![allow(clippy::new_without_default)] +#![allow(non_snake_case)] +#![allow(non_upper_case_globals)] +#![allow(non_camel_case_types)] + +pub mod tree_recursion; + +use plonky2::field::goldilocks_field::GoldilocksField; +use plonky2::plonk::config::PoseidonGoldilocksConfig; +use plonky2::plonk::proof::Proof; + +pub type F = GoldilocksField; +pub type C = PoseidonGoldilocksConfig; +pub type PlonkyProof = Proof; diff --git a/src/tree_recursion.rs b/src/tree_recursion.rs new file mode 100644 index 0000000..fcc6eaf --- /dev/null +++ b/src/tree_recursion.rs @@ -0,0 +1,558 @@ +/* + Tree recursion with conditionals. + + + p_7 + ▲ + │ + ┌─┴─┐ + │ F │ + └───┘ + ▲ ▲ + ┌─┘ └─┐ + ┌───┘ └───┐ + │p_5 │p_6 + ┌─┴─┐ ┌─┴─┐ + │ F │ │ F │ + └───┘ └───┘ + ▲ ▲ ▲ ▲ + ┌─┘ └─┐ ┌─┘ └─┐ + │ │ │ │ + p_1 p_2 p_3 p_4 + + where each p_i is either + - signature verification + - recursive plonky2 proof (proof that verifies previous proof) + (generated by `RecursiveCircuit::prove_step` method) + + and F verifies the two incoming p_i's, that is + - (signature proof OR recursive proof) AND (signature proof OR recursive proof) + and produces a new proof. + + + To run the tests that checks this logic: + cargo test --release test_tree_recursion -- --nocapture +*/ + +use anyhow::Result; +use plonky2::field::types::Field; +use plonky2::gates::noop::NoopGate; +use plonky2::iop::target::{BoolTarget, Target}; +use plonky2::iop::witness::{PartialWitness, WitnessWrite}; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2::plonk::circuit_data::{ + CircuitConfig, CircuitData, VerifierCircuitData, VerifierCircuitTarget, +}; +use plonky2::plonk::proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}; +use std::time::Instant; + +use sch::schnorr::*; +use sch::schnorr_prover::*; + +use super::{PlonkyProof, C, F}; + +/// if s==0: returns x +/// if s==1: returns y +/// Warning: this method assumes all input values are ensured to be \in {0,1} +fn selector_gate(builder: &mut CircuitBuilder, x: Target, y: Target, s: Target) -> Target { + // z = x + s(y-x) + let y_x = builder.sub(y, x); + // z = x+s(y-x) <==> mul_add(s, yx, x)=s*(y-x)+x + builder.mul_add(s, y_x, x) +} + +/// ensures b \in {0,1} +fn binary_check(builder: &mut CircuitBuilder, b: Target) { + let zero = builder.zero(); + let one = builder.one(); + // b * (b-1) == 0 + let b_1 = builder.sub(b, one); + let r = builder.mul(b, b_1); + builder.connect(r, zero); +} + +/// Contains the methods to `build` (ie. create the targets, the logic of the circuit), and +/// `fill_targets` (ie. set the specific values to be used for the previously created targets). +/// +/// The logic of this gadget verifies the given signature if `selector==0`. +/// We reuse this gadget for both the signature verifications of the left & right signatures in the +/// node of the recursion tree. +pub struct SignatureGadgetTargets { + selector_targ: Target, + selector_booltarg: BoolTarget, + + pk_targ: SchnorrPublicKeyTarget, + sig_targ: SchnorrSignatureTarget, +} +impl SignatureGadgetTargets { + pub fn build( + mut builder: &mut CircuitBuilder, + msg_targ: &MessageTarget, + ) -> Result { + let selector_targ = builder.add_virtual_target(); + // ensure that selector_booltarg is \in {0,1} + binary_check(builder, selector_targ); + let selector_booltarg = BoolTarget::new_unsafe(selector_targ); + + // signature verification: + let sb: SchnorrBuilder = SchnorrBuilder {}; + let pk_targ = SchnorrPublicKeyTarget::new_virtual(&mut builder); + let sig_targ = SchnorrSignatureTarget::new_virtual(&mut builder); + let sig_verif_targ = sb.verify_sig::(&mut builder, &sig_targ, &msg_targ, &pk_targ); + + /* + - if selector=0 + verify_sig==1 && proof_enabled=0 + - if selector=1 + verify_sig==NaN && proof_enabled=1 (don't check the sig) + + to disable the verify_sig check, when selector=1: + x=verify_sig, y=always_1, s=selector (all values \in {0,1}) + z = x + s(y-x) + */ + // if selector=0: check that sig_verif==1 + // if selector=1: check that one==1 + let one = builder.one(); + let expected = selector_gate( + builder, + sig_verif_targ.target, + one, + selector_booltarg.target, + ); + let one_2 = builder.one(); + builder.connect(expected, one_2); + + Ok(Self { + selector_targ, + selector_booltarg, + pk_targ, + sig_targ, + }) + } + pub fn fill_targets( + &mut self, + pw: &mut PartialWitness, + // left side + selector: F, // 1=proof, 0=sig + pk: &SchnorrPublicKey, + sig: &SchnorrSignature, + ) -> Result<()> { + pw.set_target(self.selector_targ, selector)?; + + // set signature related values: + self.pk_targ.set_witness(pw, &pk).unwrap(); + self.sig_targ.set_witness(pw, &sig).unwrap(); + + Ok(()) + } +} + +/// Contains the methods to `build` (ie. create the targets, the logic of the circuit), and +/// `fill_targets` (ie. set the specific values to be used for the previously created targets). +pub struct RecursiveCircuit { + msg_targ: MessageTarget, + L_sig_targets: SignatureGadgetTargets, + R_sig_targets: SignatureGadgetTargets, + // L_sig_verif_targ: BoolTarget, + L_proof_targ: ProofWithPublicInputsTarget<2>, + R_proof_targ: ProofWithPublicInputsTarget<2>, + // the next two are common for both L&R proofs. It is the data for this circuit itself (cyclic circuit). + verifier_data_targ: VerifierCircuitTarget, + verifier_data: VerifierCircuitData, +} + +impl RecursiveCircuit { + pub fn prepare_public_inputs( + verifier_data: VerifierCircuitData, + msg: Vec, + ) -> Vec { + [ + msg.clone(), + // add verifier_data as public inputs: + verifier_data.verifier_only.circuit_digest.elements.to_vec(), + verifier_data + .verifier_only + .constants_sigmas_cap + .0 + .iter() + .flat_map(|e| e.elements) + .collect(), + ] + .concat() + } + // notice that this method does not fill the targets, which is done in the method + // `fill_recursive_circuit_targets` + pub fn build( + builder: &mut CircuitBuilder, + verifier_data: VerifierCircuitData, + msg_len: usize, + ) -> Result { + let msg_targ = MessageTarget::new_with_size(builder, msg_len); + // set msg as public input + builder.register_public_inputs(&msg_targ.msg); + + // build the signature verification logic + let L_sig_targets = SignatureGadgetTargets::build(builder, &msg_targ)?; + let R_sig_targets = SignatureGadgetTargets::build(builder, &msg_targ)?; + + // proof verification: + + let common_data = verifier_data.common.clone(); + let verifier_data_targ = builder.add_verifier_data_public_inputs(); + + let L_proof_targ = builder.add_virtual_proof_with_pis(&common_data); + builder.conditionally_verify_cyclic_proof_or_dummy::( + L_sig_targets.selector_booltarg, + &L_proof_targ, + &common_data, + )?; + let R_proof_targ = builder.add_virtual_proof_with_pis(&common_data); + builder.conditionally_verify_cyclic_proof_or_dummy::( + R_sig_targets.selector_booltarg, + &R_proof_targ, + &common_data, + )?; + + Ok(Self { + msg_targ, + L_sig_targets, + R_sig_targets, + L_proof_targ, + R_proof_targ, + verifier_data_targ, + verifier_data, + }) + } + + pub fn fill_targets( + &mut self, + pw: &mut PartialWitness, + msg: &Vec, + // left side + L_selector: F, // 1=proof, 0=sig + L_pk: &SchnorrPublicKey, + L_sig: &SchnorrSignature, + L_recursive_proof: &PlonkyProof, + // right side + R_selector: F, // 1=proof, 0=sig + R_pk: &SchnorrPublicKey, + R_sig: &SchnorrSignature, + R_recursive_proof: &PlonkyProof, + ) -> Result<()> { + // set the msg value (used by both sig gadgets, left and right) + self.msg_targ.set_witness(pw, &msg).unwrap(); + + // set the signature related values + self.L_sig_targets + .fill_targets(pw, L_selector, L_pk, L_sig)?; + self.R_sig_targets + .fill_targets(pw, R_selector, R_pk, R_sig)?; + + // set proof related values: + + // recursive proofs verification + pw.set_verifier_data_target(&self.verifier_data_targ, &self.verifier_data.verifier_only)?; + + let public_inputs = + RecursiveCircuit::prepare_public_inputs(self.verifier_data.clone(), msg.clone()); + // left proof verification values + pw.set_proof_with_pis_target( + &self.L_proof_targ, + &ProofWithPublicInputs { + proof: L_recursive_proof.clone(), + public_inputs: public_inputs.clone(), + }, + )?; + // right proof verification values + pw.set_proof_with_pis_target( + &self.R_proof_targ, + &ProofWithPublicInputs { + proof: R_recursive_proof.clone(), + public_inputs, + }, + )?; + + Ok(()) + } +} + +#[derive(Debug, Clone)] +pub struct Recursion {} + +pub fn common_data_for_recursion(msg_len: usize) -> CircuitData { + // 1st + let config = CircuitConfig::standard_recursion_config(); + let builder = CircuitBuilder::::new(config); + let data = builder.build::(); + + // 2nd + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config.clone()); + let verifier_data = builder.add_virtual_verifier_data(data.common.config.fri_config.cap_height); + // left proof + let proof = builder.add_virtual_proof_with_pis(&data.common); + builder.verify_proof::(&proof, &verifier_data, &data.common); + // right proof + let proof = builder.add_virtual_proof_with_pis(&data.common); + builder.verify_proof::(&proof, &verifier_data, &data.common); + let data = builder.build::(); + + // 3rd + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config.clone()); + let msg_targ = MessageTarget::new_with_size(&mut builder, msg_len); + // sigs verify + builder.register_public_inputs(&msg_targ.msg); + + builder.add_gate( + // add a ConstantGate, because without this, when later generating the `dummy_circuit` + // (inside the `conditionally_verify_cyclic_proof_or_dummy`), it fails due the + // `CommonCircuitData` of the generated circuit not matching the given `CommonCircuitData` + // to create it. Without this it fails because it misses a ConstantGate. + plonky2::gates::constant::ConstantGate::new(config.num_constants), + vec![], + ); + + let _ = SignatureGadgetTargets::build(&mut builder, &msg_targ).unwrap(); + let _ = SignatureGadgetTargets::build(&mut builder, &msg_targ).unwrap(); + + // proofs verify + let verifier_data = builder.add_verifier_data_public_inputs(); + // left proof + let proof_L = builder.add_virtual_proof_with_pis(&data.common); + builder.verify_proof::(&proof_L, &verifier_data, &data.common); + // right proof + let proof_R = builder.add_virtual_proof_with_pis(&data.common); + builder.verify_proof::(&proof_R, &verifier_data, &data.common); + + // pad min gates + while builder.num_gates() < 1 << 13 { + builder.add_gate(NoopGate, vec![]); + } + builder.build::() +} + +impl Recursion { + /// returns the full-recursive CircuitData + pub fn circuit_data(msg_len: usize) -> Result> { + let mut data = common_data_for_recursion(msg_len); + + // build the actual RecursiveCircuit circuit data + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::new(config); + let _ = RecursiveCircuit::build(&mut builder, data.verifier_data(), msg_len)?; + data = builder.build::(); + + Ok(data) + } + + pub fn prove_step( + verifier_data: VerifierCircuitData, + msg: &Vec, + // left side + L_selector: F, // 1=proof, 0=sig + pk_L: &SchnorrPublicKey, + sig_L: &SchnorrSignature, + recursive_proof_L: &PlonkyProof, + // right side + R_selector: F, // 1=proof, 0=sig + pk_R: &SchnorrPublicKey, + sig_R: &SchnorrSignature, + recursive_proof_R: &PlonkyProof, + ) -> Result { + println!("prove_step:"); + if L_selector.is_nonzero() { + println!(" (L_selector==1), verify left proof"); + } else { + println!(" (L_selector==0), verify left signature"); + } + if R_selector.is_nonzero() { + println!(" (R_selector==1), verify right proof"); + } else { + println!(" (R_selector==0), verify right signature"); + } + + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::new(config); + + // assign the targets + let start = Instant::now(); + let mut circuit = RecursiveCircuit::build(&mut builder, verifier_data.clone(), msg.len())?; + println!("RecursiveCircuit::build(): {:?}", start.elapsed()); + + // fill the targets + let mut pw = PartialWitness::new(); + let start = Instant::now(); + circuit.fill_targets( + &mut pw, + msg, + L_selector, + pk_L, + sig_L, + recursive_proof_L, + R_selector, + pk_R, + sig_R, + recursive_proof_R, + )?; + println!("circuit.fill_targets(): {:?}", start.elapsed()); + + let start = Instant::now(); + let data = builder.build::(); + println!("builder.build(): {:?}", start.elapsed()); + + let start = Instant::now(); + let new_proof = data.prove(pw)?; + println!("generate new_proof: {:?}", start.elapsed()); + + let start = Instant::now(); + data.verify(new_proof.clone())?; + println!("verify new_proof: {:?}", start.elapsed()); + + #[cfg(test)] + data.verifier_data().verify(ProofWithPublicInputs { + proof: new_proof.proof.clone(), + public_inputs: new_proof.public_inputs.clone(), + })?; + + #[cfg(test)] + verifier_data.verify(ProofWithPublicInputs { + proof: new_proof.proof.clone(), + public_inputs: new_proof.public_inputs.clone(), + })?; + + Ok(new_proof.proof) + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use hashbrown::HashMap; + use plonky2::field::types::Field; + use plonky2::plonk::proof::ProofWithPublicInputs; + use plonky2::recursion::dummy_circuit::cyclic_base_proof; + use rand; + use std::time::Instant; + + use super::*; + + // this sets the plonky2 internal logs level + fn set_log() { + let _ = env_logger::builder() + .filter_level(log::LevelFilter::Warn) + .is_test(true) + .try_init(); + } + + /// to run: + /// cargo test --release test_tree_recursion -- --nocapture + #[test] + fn test_tree_recursion() -> Result<()> { + set_log(); + let mut rng: rand::rngs::ThreadRng = rand::thread_rng(); + let schnorr = SchnorrSigner::new(); + const MSG_LEN: usize = 5; + let msg: Vec = schnorr.u64_into_goldilocks_vec(vec![1500, 1600, 2, 2, 2]); + assert_eq!(msg.len(), MSG_LEN); + + let l: u32 = 2; // levels of the recursion (binary) tree + let k = 2_u32.pow(l); // number of leafs in the recursion tree + + // generate k key pairs + let sk_vec: Vec = + (0..k).map(|i| SchnorrSecretKey { sk: i as u64 }).collect(); + let pk_vec: Vec = sk_vec.iter().map(|&sk| schnorr.keygen(&sk)).collect(); + + let sig_vec: Vec = sk_vec + .iter() + .map(|&sk| schnorr.sign(&msg, &sk, &mut rng)) + .collect(); + + // build the circuit_data & verifier_data for the recursive circuit + let circuit_data = Recursion::circuit_data(MSG_LEN)?; + let verifier_data = circuit_data.verifier_data(); + + // let dummy_circuit = dummy_circuit::(&circuit_data.common); // WIP + // let dummy_proof_pis = dummy_proof(&dummy_circuit, HashMap::new())?; // WIP + let dummy_proof_pis = cyclic_base_proof( + &circuit_data.common, + &verifier_data.verifier_only, + HashMap::new(), + ); + let dummy_proof = dummy_proof_pis.proof; + + // we start with k dummy proofs, since at the leafs level we don't have proofs yet and we + // just verify the signatures + let mut proofs_at_level_i: Vec = + (0..k).into_iter().map(|_| dummy_proof.clone()).collect(); + + // loop over the recursion levels + for i in 0..l { + println!("\n=== recursion level i={}", i); + let mut next_level_proofs: Vec = vec![]; + + // loop over the nodes of each recursion tree level + for j in (0..proofs_at_level_i.len()).into_iter().step_by(2) { + println!("\n------ recursion node i={}, j={}", i, j); + + // - if we're at the first level of the recursion tree: + // proof_enabled=false=0, so that the circuit verifies the signature and not the proof. + // - else: + // proof_enabled=true=1, so that the circuit verifies the proof and not the signature. + // + // In future tests we will try other cases (eg. left sig, right proof), but for + // the moment we just do base_case: sig verify, other cases: proof verify. + let proof_enabled = if i == 0 { F::ZERO } else { F::ONE }; + + // do the recursive step + let start = Instant::now(); + let new_proof = Recursion::prove_step( + verifier_data.clone(), + &msg, + // left side: + proof_enabled, + &pk_vec[j], + &sig_vec[j], + &proofs_at_level_i[j], + // right side + proof_enabled, + &pk_vec[j + 1], + &sig_vec[j + 1], + &proofs_at_level_i[j + 1], + )?; + println!( + "Recursion::prove_step (level: i={}, node: j={}) took: {:?}", + i, + j, + start.elapsed() + ); + + // verify the recursive proof + let public_inputs = + RecursiveCircuit::prepare_public_inputs(verifier_data.clone(), msg.clone()); + verifier_data.clone().verify(ProofWithPublicInputs { + proof: new_proof.clone(), + public_inputs: public_inputs.clone(), + })?; + + // set new_proof for next iteration + next_level_proofs.push(new_proof); + } + proofs_at_level_i = next_level_proofs.clone(); + } + assert_eq!(proofs_at_level_i.len(), 1); + let last_proof = proofs_at_level_i[0].clone(); + + // verify the last proof + let public_inputs = + RecursiveCircuit::prepare_public_inputs(verifier_data.clone(), msg.clone()); + verifier_data.clone().verify(ProofWithPublicInputs { + proof: last_proof.clone(), + public_inputs: public_inputs.clone(), + })?; + + Ok(()) + } + // WIP will add more tests with other sig/proof combinations +}