Browse Source

feat (circom): allow to define the step_native in Rust (#105)

* create function pointer

* custom logic via function pointer

* fmt

* clippy

* rust-version

* update review code

* fmt
main
Vu Vo 6 months ago
committed by GitHub
parent
commit
dd8dacb53b
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
1 changed files with 102 additions and 32 deletions
  1. +102
    -32
      folding-schemes/src/frontend/circom/mod.rs

+ 102
- 32
folding-schemes/src/frontend/circom/mod.rs

@ -10,17 +10,85 @@ use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystemRef, SynthesisE
use ark_std::fmt::Debug; use ark_std::fmt::Debug;
use num_bigint::BigInt; use num_bigint::BigInt;
use std::path::PathBuf; use std::path::PathBuf;
use std::rc::Rc;
use std::{fmt, usize};
pub mod utils; pub mod utils;
use utils::CircomWrapper; use utils::CircomWrapper;
type ClosurePointer<F> = Rc<dyn Fn(usize, Vec<F>, Vec<F>) -> Result<Vec<F>, Error>>;
#[derive(Clone)]
struct CustomStepNative<F: PrimeField> {
func: ClosurePointer<F>,
}
impl<F: PrimeField> fmt::Debug for CustomStepNative<F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"Function pointer: {:?}",
std::any::type_name::<fn(usize, Vec<F>, Vec<F>) -> Result<Vec<F>, Error>>()
)
}
}
/// Define CircomFCircuit /// Define CircomFCircuit
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct CircomFCircuit<F: PrimeField> { pub struct CircomFCircuit<F: PrimeField> {
circom_wrapper: CircomWrapper<F>, circom_wrapper: CircomWrapper<F>,
state_len: usize,
external_inputs_len: usize,
pub state_len: usize,
pub external_inputs_len: usize,
r1cs: CircomR1CS<F>, r1cs: CircomR1CS<F>,
custom_step_native_code: Option<CustomStepNative<F>>,
}
impl<F: PrimeField> CircomFCircuit<F> {
pub fn set_custom_step_native(&mut self, func: ClosurePointer<F>) {
self.custom_step_native_code = Some(CustomStepNative::<F> { func });
}
pub fn execute_custom_step_native(
&self,
_i: usize,
z_i: Vec<F>,
external_inputs: Vec<F>,
) -> Result<Vec<F>, Error> {
if let Some(code) = &self.custom_step_native_code {
(code.func)(_i, z_i, external_inputs)
} else {
#[cfg(test)]
assert_eq!(z_i.len(), self.state_len());
#[cfg(test)]
assert_eq!(external_inputs.len(), self.external_inputs_len());
let inputs_bi = z_i
.iter()
.map(|val| self.circom_wrapper.ark_primefield_to_num_bigint(*val))
.collect::<Vec<BigInt>>();
let mut inputs_map = vec![("ivc_input".to_string(), inputs_bi)];
if self.external_inputs_len() > 0 {
let external_inputs_bi = external_inputs
.iter()
.map(|val| self.circom_wrapper.ark_primefield_to_num_bigint(*val))
.collect::<Vec<BigInt>>();
inputs_map.push(("external_inputs".to_string(), external_inputs_bi));
}
// Computes witness
let witness = self
.circom_wrapper
.extract_witness(&inputs_map)
.map_err(|e| {
Error::WitnessCalculationError(format!("Failed to calculate witness: {}", e))
})?;
// Extracts the z_i1(next state) from the witness vector.
let z_i1 = witness[1..1 + self.state_len()].to_vec();
Ok(z_i1)
}
}
} }
impl<F: PrimeField> FCircuit<F> for CircomFCircuit<F> { impl<F: PrimeField> FCircuit<F> for CircomFCircuit<F> {
@ -37,6 +105,7 @@ impl FCircuit for CircomFCircuit {
state_len, state_len,
external_inputs_len, external_inputs_len,
r1cs, r1cs,
custom_step_native_code: None,
}) })
} }
@ -53,36 +122,7 @@ impl FCircuit for CircomFCircuit {
z_i: Vec<F>, z_i: Vec<F>,
external_inputs: Vec<F>, external_inputs: Vec<F>,
) -> Result<Vec<F>, Error> { ) -> Result<Vec<F>, Error> {
#[cfg(test)]
assert_eq!(z_i.len(), self.state_len());
#[cfg(test)]
assert_eq!(external_inputs.len(), self.external_inputs_len());
let inputs_bi = z_i
.iter()
.map(|val| self.circom_wrapper.ark_primefield_to_num_bigint(*val))
.collect::<Vec<BigInt>>();
let mut inputs_map = vec![("ivc_input".to_string(), inputs_bi)];
if self.external_inputs_len() > 0 {
let external_inputs_bi = external_inputs
.iter()
.map(|val| self.circom_wrapper.ark_primefield_to_num_bigint(*val))
.collect::<Vec<BigInt>>();
inputs_map.push(("external_inputs".to_string(), external_inputs_bi));
}
// Computes witness
let witness = self
.circom_wrapper
.extract_witness(&inputs_map)
.map_err(|e| {
Error::WitnessCalculationError(format!("Failed to calculate witness: {}", e))
})?;
// Extracts the z_i1(next state) from the witness vector.
let z_i1 = witness[1..1 + self.state_len()].to_vec();
Ok(z_i1)
self.execute_custom_step_native(_i, z_i, external_inputs)
} }
fn generate_step_constraints( fn generate_step_constraints(
@ -305,4 +345,34 @@ pub mod tests {
// Disable check for now // Disable check for now
// assert!(z_i1_var.is_err()) // assert!(z_i1_var.is_err())
} }
#[test]
fn test_custom_code() {
let r1cs_path = PathBuf::from("./src/frontend/circom/test_folder/cubic_circuit.r1cs");
let wasm_path =
PathBuf::from("./src/frontend/circom/test_folder/cubic_circuit_js/cubic_circuit.wasm");
let mut circom_fcircuit = CircomFCircuit::<Fr>::new((r1cs_path, wasm_path, 1, 0)).unwrap(); // state_len:1, external_inputs_len:0
circom_fcircuit.set_custom_step_native(Rc::new(|_i, z_i, _external| {
let z = z_i[0];
Ok(vec![z * z * z + z + Fr::from(5)])
}));
// Allocates z_i1 by using step_native function.
let z_i = vec![Fr::from(3_u32)];
let wrapper_circuit = crate::frontend::tests::WrapperCircuit {
FC: circom_fcircuit.clone(),
z_i: Some(z_i.clone()),
z_i1: Some(circom_fcircuit.step_native(0, z_i.clone(), vec![]).unwrap()),
};
let cs = ConstraintSystem::<Fr>::new_ref();
wrapper_circuit.generate_constraints(cs.clone()).unwrap();
assert!(
cs.is_satisfied().unwrap(),
"Constraint system is not satisfied"
);
}
} }

Loading…
Cancel
Save