From dd8dacb53b25a4c4f48481e3af5b49b94904dbfc Mon Sep 17 00:00:00 2001 From: Vu Vo Date: Thu, 6 Jun 2024 19:41:50 +0700 Subject: [PATCH] feat (circom): allow to define the step_native in Rust (#105) * create function pointer * custom logic via function pointer * fmt * clippy * rust-version * update review code * fmt --- folding-schemes/src/frontend/circom/mod.rs | 134 ++++++++++++++++----- 1 file changed, 102 insertions(+), 32 deletions(-) diff --git a/folding-schemes/src/frontend/circom/mod.rs b/folding-schemes/src/frontend/circom/mod.rs index 6d05555..9259a94 100644 --- a/folding-schemes/src/frontend/circom/mod.rs +++ b/folding-schemes/src/frontend/circom/mod.rs @@ -10,17 +10,85 @@ use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystemRef, SynthesisE use ark_std::fmt::Debug; use num_bigint::BigInt; use std::path::PathBuf; +use std::rc::Rc; +use std::{fmt, usize}; pub mod utils; use utils::CircomWrapper; +type ClosurePointer = Rc, Vec) -> Result, Error>>; + +#[derive(Clone)] +struct CustomStepNative { + func: ClosurePointer, +} + +impl fmt::Debug for CustomStepNative { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "Function pointer: {:?}", + std::any::type_name::, Vec) -> Result, Error>>() + ) + } +} + /// Define CircomFCircuit #[derive(Clone, Debug)] pub struct CircomFCircuit { circom_wrapper: CircomWrapper, - state_len: usize, - external_inputs_len: usize, + pub state_len: usize, + pub external_inputs_len: usize, r1cs: CircomR1CS, + custom_step_native_code: Option>, +} + +impl CircomFCircuit { + pub fn set_custom_step_native(&mut self, func: ClosurePointer) { + self.custom_step_native_code = Some(CustomStepNative:: { func }); + } + + pub fn execute_custom_step_native( + &self, + _i: usize, + z_i: Vec, + external_inputs: Vec, + ) -> Result, Error> { + if let Some(code) = &self.custom_step_native_code { + (code.func)(_i, z_i, external_inputs) + } else { + #[cfg(test)] + assert_eq!(z_i.len(), self.state_len()); + #[cfg(test)] + assert_eq!(external_inputs.len(), self.external_inputs_len()); + + let inputs_bi = z_i + .iter() + .map(|val| self.circom_wrapper.ark_primefield_to_num_bigint(*val)) + .collect::>(); + let mut inputs_map = vec![("ivc_input".to_string(), inputs_bi)]; + + if self.external_inputs_len() > 0 { + let external_inputs_bi = external_inputs + .iter() + .map(|val| self.circom_wrapper.ark_primefield_to_num_bigint(*val)) + .collect::>(); + inputs_map.push(("external_inputs".to_string(), external_inputs_bi)); + } + + // Computes witness + let witness = self + .circom_wrapper + .extract_witness(&inputs_map) + .map_err(|e| { + Error::WitnessCalculationError(format!("Failed to calculate witness: {}", e)) + })?; + + // Extracts the z_i1(next state) from the witness vector. + let z_i1 = witness[1..1 + self.state_len()].to_vec(); + Ok(z_i1) + } + } } impl FCircuit for CircomFCircuit { @@ -37,6 +105,7 @@ impl FCircuit for CircomFCircuit { state_len, external_inputs_len, r1cs, + custom_step_native_code: None, }) } @@ -53,36 +122,7 @@ impl FCircuit for CircomFCircuit { z_i: Vec, external_inputs: Vec, ) -> Result, Error> { - #[cfg(test)] - assert_eq!(z_i.len(), self.state_len()); - #[cfg(test)] - assert_eq!(external_inputs.len(), self.external_inputs_len()); - - let inputs_bi = z_i - .iter() - .map(|val| self.circom_wrapper.ark_primefield_to_num_bigint(*val)) - .collect::>(); - let mut inputs_map = vec![("ivc_input".to_string(), inputs_bi)]; - - if self.external_inputs_len() > 0 { - let external_inputs_bi = external_inputs - .iter() - .map(|val| self.circom_wrapper.ark_primefield_to_num_bigint(*val)) - .collect::>(); - inputs_map.push(("external_inputs".to_string(), external_inputs_bi)); - } - - // Computes witness - let witness = self - .circom_wrapper - .extract_witness(&inputs_map) - .map_err(|e| { - Error::WitnessCalculationError(format!("Failed to calculate witness: {}", e)) - })?; - - // Extracts the z_i1(next state) from the witness vector. - let z_i1 = witness[1..1 + self.state_len()].to_vec(); - Ok(z_i1) + self.execute_custom_step_native(_i, z_i, external_inputs) } fn generate_step_constraints( @@ -305,4 +345,34 @@ pub mod tests { // Disable check for now // assert!(z_i1_var.is_err()) } + + #[test] + fn test_custom_code() { + let r1cs_path = PathBuf::from("./src/frontend/circom/test_folder/cubic_circuit.r1cs"); + let wasm_path = + PathBuf::from("./src/frontend/circom/test_folder/cubic_circuit_js/cubic_circuit.wasm"); + + let mut circom_fcircuit = CircomFCircuit::::new((r1cs_path, wasm_path, 1, 0)).unwrap(); // state_len:1, external_inputs_len:0 + + circom_fcircuit.set_custom_step_native(Rc::new(|_i, z_i, _external| { + let z = z_i[0]; + Ok(vec![z * z * z + z + Fr::from(5)]) + })); + + // Allocates z_i1 by using step_native function. + let z_i = vec![Fr::from(3_u32)]; + let wrapper_circuit = crate::frontend::tests::WrapperCircuit { + FC: circom_fcircuit.clone(), + z_i: Some(z_i.clone()), + z_i1: Some(circom_fcircuit.step_native(0, z_i.clone(), vec![]).unwrap()), + }; + + let cs = ConstraintSystem::::new_ref(); + + wrapper_circuit.generate_constraints(cs.clone()).unwrap(); + assert!( + cs.is_satisfied().unwrap(), + "Constraint system is not satisfied" + ); + } }