Browse Source

add full recursion (binary) tree

Each node of the tree is verifying:
`((verify left sig OR verify left proof) AND (verify right sig OR verify right proof))`.
and then it generates a new plonky2 proof, which can again be verified
in a node of the next level of the tree.

full binary tree of recursion works
main
arnaucube 6 months ago
commit
55a08ad575
5 changed files with 596 additions and 0 deletions
  1. +2
    -0
      .gitignore
  2. +13
    -0
      Cargo.toml
  3. +9
    -0
      README.md
  4. +14
    -0
      src/lib.rs
  5. +558
    -0
      src/tree_recursion.rs

+ 2
- 0
.gitignore

@ -0,0 +1,2 @@
/target
Cargo.lock

+ 13
- 0
Cargo.toml

@ -0,0 +1,13 @@
[package]
name = "plonky2-recursion-experiment"
version = "0.1.0"
edition = "2021"
[dependencies]
plonky2 = { git = "https://github.com/mir-protocol/plonky2" }
sch = { git = "https://github.com/tideofwords/schnorr" }
anyhow = "1.0.56"
rand = "0.8.5"
hashbrown = { version = "0.14.3", default-features = false, features = ["ahash", "serde"] }
log = { version = "0.4.14", default-features = false }
env_logger = "0.10.0"

+ 9
- 0
README.md

@ -0,0 +1,9 @@
# plonky2-recursion-experiment
WARNING: Experimental. Not safe, do not use in production.
## Run
```bash
rustup override set nightly # Requires nightly Rust
cargo test --release -- --nocapture
```

+ 14
- 0
src/lib.rs

@ -0,0 +1,14 @@
#![allow(clippy::new_without_default)]
#![allow(non_snake_case)]
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
pub mod tree_recursion;
use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::plonk::config::PoseidonGoldilocksConfig;
use plonky2::plonk::proof::Proof;
pub type F = GoldilocksField;
pub type C = PoseidonGoldilocksConfig;
pub type PlonkyProof = Proof<F, PoseidonGoldilocksConfig, 2>;

+ 558
- 0
src/tree_recursion.rs

@ -0,0 +1,558 @@
/*
Tree recursion with conditionals.
p_7
F
p_5 p_6
F F
p_1 p_2 p_3 p_4
where each p_i is either
- signature verification
- recursive plonky2 proof (proof that verifies previous proof)
(generated by `RecursiveCircuit::prove_step` method)
and F verifies the two incoming p_i's, that is
- (signature proof OR recursive proof) AND (signature proof OR recursive proof)
and produces a new proof.
To run the tests that checks this logic:
cargo test --release test_tree_recursion -- --nocapture
*/
use anyhow::Result;
use plonky2::field::types::Field;
use plonky2::gates::noop::NoopGate;
use plonky2::iop::target::{BoolTarget, Target};
use plonky2::iop::witness::{PartialWitness, WitnessWrite};
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::plonk::circuit_data::{
CircuitConfig, CircuitData, VerifierCircuitData, VerifierCircuitTarget,
};
use plonky2::plonk::proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget};
use std::time::Instant;
use sch::schnorr::*;
use sch::schnorr_prover::*;
use super::{PlonkyProof, C, F};
/// if s==0: returns x
/// if s==1: returns y
/// Warning: this method assumes all input values are ensured to be \in {0,1}
fn selector_gate(builder: &mut CircuitBuilder<F, 2>, x: Target, y: Target, s: Target) -> Target {
// z = x + s(y-x)
let y_x = builder.sub(y, x);
// z = x+s(y-x) <==> mul_add(s, yx, x)=s*(y-x)+x
builder.mul_add(s, y_x, x)
}
/// ensures b \in {0,1}
fn binary_check(builder: &mut CircuitBuilder<F, 2>, b: Target) {
let zero = builder.zero();
let one = builder.one();
// b * (b-1) == 0
let b_1 = builder.sub(b, one);
let r = builder.mul(b, b_1);
builder.connect(r, zero);
}
/// Contains the methods to `build` (ie. create the targets, the logic of the circuit), and
/// `fill_targets` (ie. set the specific values to be used for the previously created targets).
///
/// The logic of this gadget verifies the given signature if `selector==0`.
/// We reuse this gadget for both the signature verifications of the left & right signatures in the
/// node of the recursion tree.
pub struct SignatureGadgetTargets {
selector_targ: Target,
selector_booltarg: BoolTarget,
pk_targ: SchnorrPublicKeyTarget,
sig_targ: SchnorrSignatureTarget,
}
impl SignatureGadgetTargets {
pub fn build(
mut builder: &mut CircuitBuilder<F, 2>,
msg_targ: &MessageTarget,
) -> Result<SignatureGadgetTargets> {
let selector_targ = builder.add_virtual_target();
// ensure that selector_booltarg is \in {0,1}
binary_check(builder, selector_targ);
let selector_booltarg = BoolTarget::new_unsafe(selector_targ);
// signature verification:
let sb: SchnorrBuilder = SchnorrBuilder {};
let pk_targ = SchnorrPublicKeyTarget::new_virtual(&mut builder);
let sig_targ = SchnorrSignatureTarget::new_virtual(&mut builder);
let sig_verif_targ = sb.verify_sig::<C>(&mut builder, &sig_targ, &msg_targ, &pk_targ);
/*
- if selector=0
verify_sig==1 && proof_enabled=0
- if selector=1
verify_sig==NaN && proof_enabled=1 (don't check the sig)
to disable the verify_sig check, when selector=1:
x=verify_sig, y=always_1, s=selector (all values \in {0,1})
z = x + s(y-x)
*/
// if selector=0: check that sig_verif==1
// if selector=1: check that one==1
let one = builder.one();
let expected = selector_gate(
builder,
sig_verif_targ.target,
one,
selector_booltarg.target,
);
let one_2 = builder.one();
builder.connect(expected, one_2);
Ok(Self {
selector_targ,
selector_booltarg,
pk_targ,
sig_targ,
})
}
pub fn fill_targets(
&mut self,
pw: &mut PartialWitness<F>,
// left side
selector: F, // 1=proof, 0=sig
pk: &SchnorrPublicKey,
sig: &SchnorrSignature,
) -> Result<()> {
pw.set_target(self.selector_targ, selector)?;
// set signature related values:
self.pk_targ.set_witness(pw, &pk).unwrap();
self.sig_targ.set_witness(pw, &sig).unwrap();
Ok(())
}
}
/// Contains the methods to `build` (ie. create the targets, the logic of the circuit), and
/// `fill_targets` (ie. set the specific values to be used for the previously created targets).
pub struct RecursiveCircuit {
msg_targ: MessageTarget,
L_sig_targets: SignatureGadgetTargets,
R_sig_targets: SignatureGadgetTargets,
// L_sig_verif_targ: BoolTarget,
L_proof_targ: ProofWithPublicInputsTarget<2>,
R_proof_targ: ProofWithPublicInputsTarget<2>,
// the next two are common for both L&R proofs. It is the data for this circuit itself (cyclic circuit).
verifier_data_targ: VerifierCircuitTarget,
verifier_data: VerifierCircuitData<F, C, 2>,
}
impl RecursiveCircuit {
pub fn prepare_public_inputs(
verifier_data: VerifierCircuitData<F, C, 2>,
msg: Vec<F>,
) -> Vec<F> {
[
msg.clone(),
// add verifier_data as public inputs:
verifier_data.verifier_only.circuit_digest.elements.to_vec(),
verifier_data
.verifier_only
.constants_sigmas_cap
.0
.iter()
.flat_map(|e| e.elements)
.collect(),
]
.concat()
}
// notice that this method does not fill the targets, which is done in the method
// `fill_recursive_circuit_targets`
pub fn build(
builder: &mut CircuitBuilder<F, 2>,
verifier_data: VerifierCircuitData<F, C, 2>,
msg_len: usize,
) -> Result<Self> {
let msg_targ = MessageTarget::new_with_size(builder, msg_len);
// set msg as public input
builder.register_public_inputs(&msg_targ.msg);
// build the signature verification logic
let L_sig_targets = SignatureGadgetTargets::build(builder, &msg_targ)?;
let R_sig_targets = SignatureGadgetTargets::build(builder, &msg_targ)?;
// proof verification:
let common_data = verifier_data.common.clone();
let verifier_data_targ = builder.add_verifier_data_public_inputs();
let L_proof_targ = builder.add_virtual_proof_with_pis(&common_data);
builder.conditionally_verify_cyclic_proof_or_dummy::<C>(
L_sig_targets.selector_booltarg,
&L_proof_targ,
&common_data,
)?;
let R_proof_targ = builder.add_virtual_proof_with_pis(&common_data);
builder.conditionally_verify_cyclic_proof_or_dummy::<C>(
R_sig_targets.selector_booltarg,
&R_proof_targ,
&common_data,
)?;
Ok(Self {
msg_targ,
L_sig_targets,
R_sig_targets,
L_proof_targ,
R_proof_targ,
verifier_data_targ,
verifier_data,
})
}
pub fn fill_targets(
&mut self,
pw: &mut PartialWitness<F>,
msg: &Vec<F>,
// left side
L_selector: F, // 1=proof, 0=sig
L_pk: &SchnorrPublicKey,
L_sig: &SchnorrSignature,
L_recursive_proof: &PlonkyProof,
// right side
R_selector: F, // 1=proof, 0=sig
R_pk: &SchnorrPublicKey,
R_sig: &SchnorrSignature,
R_recursive_proof: &PlonkyProof,
) -> Result<()> {
// set the msg value (used by both sig gadgets, left and right)
self.msg_targ.set_witness(pw, &msg).unwrap();
// set the signature related values
self.L_sig_targets
.fill_targets(pw, L_selector, L_pk, L_sig)?;
self.R_sig_targets
.fill_targets(pw, R_selector, R_pk, R_sig)?;
// set proof related values:
// recursive proofs verification
pw.set_verifier_data_target(&self.verifier_data_targ, &self.verifier_data.verifier_only)?;
let public_inputs =
RecursiveCircuit::prepare_public_inputs(self.verifier_data.clone(), msg.clone());
// left proof verification values
pw.set_proof_with_pis_target(
&self.L_proof_targ,
&ProofWithPublicInputs {
proof: L_recursive_proof.clone(),
public_inputs: public_inputs.clone(),
},
)?;
// right proof verification values
pw.set_proof_with_pis_target(
&self.R_proof_targ,
&ProofWithPublicInputs {
proof: R_recursive_proof.clone(),
public_inputs,
},
)?;
Ok(())
}
}
#[derive(Debug, Clone)]
pub struct Recursion {}
pub fn common_data_for_recursion(msg_len: usize) -> CircuitData<F, C, 2> {
// 1st
let config = CircuitConfig::standard_recursion_config();
let builder = CircuitBuilder::<F, 2>::new(config);
let data = builder.build::<C>();
// 2nd
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, 2>::new(config.clone());
let verifier_data = builder.add_virtual_verifier_data(data.common.config.fri_config.cap_height);
// left proof
let proof = builder.add_virtual_proof_with_pis(&data.common);
builder.verify_proof::<C>(&proof, &verifier_data, &data.common);
// right proof
let proof = builder.add_virtual_proof_with_pis(&data.common);
builder.verify_proof::<C>(&proof, &verifier_data, &data.common);
let data = builder.build::<C>();
// 3rd
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::<F, 2>::new(config.clone());
let msg_targ = MessageTarget::new_with_size(&mut builder, msg_len);
// sigs verify
builder.register_public_inputs(&msg_targ.msg);
builder.add_gate(
// add a ConstantGate, because without this, when later generating the `dummy_circuit`
// (inside the `conditionally_verify_cyclic_proof_or_dummy`), it fails due the
// `CommonCircuitData` of the generated circuit not matching the given `CommonCircuitData`
// to create it. Without this it fails because it misses a ConstantGate.
plonky2::gates::constant::ConstantGate::new(config.num_constants),
vec![],
);
let _ = SignatureGadgetTargets::build(&mut builder, &msg_targ).unwrap();
let _ = SignatureGadgetTargets::build(&mut builder, &msg_targ).unwrap();
// proofs verify
let verifier_data = builder.add_verifier_data_public_inputs();
// left proof
let proof_L = builder.add_virtual_proof_with_pis(&data.common);
builder.verify_proof::<C>(&proof_L, &verifier_data, &data.common);
// right proof
let proof_R = builder.add_virtual_proof_with_pis(&data.common);
builder.verify_proof::<C>(&proof_R, &verifier_data, &data.common);
// pad min gates
while builder.num_gates() < 1 << 13 {
builder.add_gate(NoopGate, vec![]);
}
builder.build::<C>()
}
impl Recursion {
/// returns the full-recursive CircuitData
pub fn circuit_data(msg_len: usize) -> Result<CircuitData<F, C, 2>> {
let mut data = common_data_for_recursion(msg_len);
// build the actual RecursiveCircuit circuit data
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::new(config);
let _ = RecursiveCircuit::build(&mut builder, data.verifier_data(), msg_len)?;
data = builder.build::<C>();
Ok(data)
}
pub fn prove_step(
verifier_data: VerifierCircuitData<F, C, 2>,
msg: &Vec<F>,
// left side
L_selector: F, // 1=proof, 0=sig
pk_L: &SchnorrPublicKey,
sig_L: &SchnorrSignature,
recursive_proof_L: &PlonkyProof,
// right side
R_selector: F, // 1=proof, 0=sig
pk_R: &SchnorrPublicKey,
sig_R: &SchnorrSignature,
recursive_proof_R: &PlonkyProof,
) -> Result<PlonkyProof> {
println!("prove_step:");
if L_selector.is_nonzero() {
println!(" (L_selector==1), verify left proof");
} else {
println!(" (L_selector==0), verify left signature");
}
if R_selector.is_nonzero() {
println!(" (R_selector==1), verify right proof");
} else {
println!(" (R_selector==0), verify right signature");
}
let config = CircuitConfig::standard_recursion_config();
let mut builder = CircuitBuilder::new(config);
// assign the targets
let start = Instant::now();
let mut circuit = RecursiveCircuit::build(&mut builder, verifier_data.clone(), msg.len())?;
println!("RecursiveCircuit::build(): {:?}", start.elapsed());
// fill the targets
let mut pw = PartialWitness::new();
let start = Instant::now();
circuit.fill_targets(
&mut pw,
msg,
L_selector,
pk_L,
sig_L,
recursive_proof_L,
R_selector,
pk_R,
sig_R,
recursive_proof_R,
)?;
println!("circuit.fill_targets(): {:?}", start.elapsed());
let start = Instant::now();
let data = builder.build::<C>();
println!("builder.build(): {:?}", start.elapsed());
let start = Instant::now();
let new_proof = data.prove(pw)?;
println!("generate new_proof: {:?}", start.elapsed());
let start = Instant::now();
data.verify(new_proof.clone())?;
println!("verify new_proof: {:?}", start.elapsed());
#[cfg(test)]
data.verifier_data().verify(ProofWithPublicInputs {
proof: new_proof.proof.clone(),
public_inputs: new_proof.public_inputs.clone(),
})?;
#[cfg(test)]
verifier_data.verify(ProofWithPublicInputs {
proof: new_proof.proof.clone(),
public_inputs: new_proof.public_inputs.clone(),
})?;
Ok(new_proof.proof)
}
}
#[cfg(test)]
mod tests {
use anyhow::Result;
use hashbrown::HashMap;
use plonky2::field::types::Field;
use plonky2::plonk::proof::ProofWithPublicInputs;
use plonky2::recursion::dummy_circuit::cyclic_base_proof;
use rand;
use std::time::Instant;
use super::*;
// this sets the plonky2 internal logs level
fn set_log() {
let _ = env_logger::builder()
.filter_level(log::LevelFilter::Warn)
.is_test(true)
.try_init();
}
/// to run:
/// cargo test --release test_tree_recursion -- --nocapture
#[test]
fn test_tree_recursion() -> Result<()> {
set_log();
let mut rng: rand::rngs::ThreadRng = rand::thread_rng();
let schnorr = SchnorrSigner::new();
const MSG_LEN: usize = 5;
let msg: Vec<F> = schnorr.u64_into_goldilocks_vec(vec![1500, 1600, 2, 2, 2]);
assert_eq!(msg.len(), MSG_LEN);
let l: u32 = 2; // levels of the recursion (binary) tree
let k = 2_u32.pow(l); // number of leafs in the recursion tree
// generate k key pairs
let sk_vec: Vec<SchnorrSecretKey> =
(0..k).map(|i| SchnorrSecretKey { sk: i as u64 }).collect();
let pk_vec: Vec<SchnorrPublicKey> = sk_vec.iter().map(|&sk| schnorr.keygen(&sk)).collect();
let sig_vec: Vec<SchnorrSignature> = sk_vec
.iter()
.map(|&sk| schnorr.sign(&msg, &sk, &mut rng))
.collect();
// build the circuit_data & verifier_data for the recursive circuit
let circuit_data = Recursion::circuit_data(MSG_LEN)?;
let verifier_data = circuit_data.verifier_data();
// let dummy_circuit = dummy_circuit::<F, C, 2>(&circuit_data.common); // WIP
// let dummy_proof_pis = dummy_proof(&dummy_circuit, HashMap::new())?; // WIP
let dummy_proof_pis = cyclic_base_proof(
&circuit_data.common,
&verifier_data.verifier_only,
HashMap::new(),
);
let dummy_proof = dummy_proof_pis.proof;
// we start with k dummy proofs, since at the leafs level we don't have proofs yet and we
// just verify the signatures
let mut proofs_at_level_i: Vec<PlonkyProof> =
(0..k).into_iter().map(|_| dummy_proof.clone()).collect();
// loop over the recursion levels
for i in 0..l {
println!("\n=== recursion level i={}", i);
let mut next_level_proofs: Vec<PlonkyProof> = vec![];
// loop over the nodes of each recursion tree level
for j in (0..proofs_at_level_i.len()).into_iter().step_by(2) {
println!("\n------ recursion node i={}, j={}", i, j);
// - if we're at the first level of the recursion tree:
// proof_enabled=false=0, so that the circuit verifies the signature and not the proof.
// - else:
// proof_enabled=true=1, so that the circuit verifies the proof and not the signature.
//
// In future tests we will try other cases (eg. left sig, right proof), but for
// the moment we just do base_case: sig verify, other cases: proof verify.
let proof_enabled = if i == 0 { F::ZERO } else { F::ONE };
// do the recursive step
let start = Instant::now();
let new_proof = Recursion::prove_step(
verifier_data.clone(),
&msg,
// left side:
proof_enabled,
&pk_vec[j],
&sig_vec[j],
&proofs_at_level_i[j],
// right side
proof_enabled,
&pk_vec[j + 1],
&sig_vec[j + 1],
&proofs_at_level_i[j + 1],
)?;
println!(
"Recursion::prove_step (level: i={}, node: j={}) took: {:?}",
i,
j,
start.elapsed()
);
// verify the recursive proof
let public_inputs =
RecursiveCircuit::prepare_public_inputs(verifier_data.clone(), msg.clone());
verifier_data.clone().verify(ProofWithPublicInputs {
proof: new_proof.clone(),
public_inputs: public_inputs.clone(),
})?;
// set new_proof for next iteration
next_level_proofs.push(new_proof);
}
proofs_at_level_i = next_level_proofs.clone();
}
assert_eq!(proofs_at_level_i.len(), 1);
let last_proof = proofs_at_level_i[0].clone();
// verify the last proof
let public_inputs =
RecursiveCircuit::prepare_public_inputs(verifier_data.clone(), msg.clone());
verifier_data.clone().verify(ProofWithPublicInputs {
proof: last_proof.clone(),
public_inputs: public_inputs.clone(),
})?;
Ok(())
}
// WIP will add more tests with other sig/proof combinations
}

Loading…
Cancel
Save