diff --git a/Cargo.toml b/Cargo.toml index 6175951..818cd1d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,8 +53,7 @@ name = "groth16" harness = false [features] -default = ["wasmer/default", "circom-2", "ethereum"] +default = ["wasmer/default", "ethereum"] wasm = ["wasmer/js-default"] bench-complex-all = [] -circom-2 = [] ethereum = ["ethers-core"] diff --git a/src/witness/circom.rs b/src/witness/circom.rs index c88c1b9..ad3053d 100644 --- a/src/witness/circom.rs +++ b/src/witness/circom.rs @@ -7,164 +7,57 @@ pub struct Wasm { pub memory: Memory, } -pub trait CircomBase { - fn init(&self, store: &mut Store, sanity_check: bool) -> Result<()>; - fn func(&self, name: &str) -> &Function; - fn get_n_vars(&self, store: &mut Store) -> Result; - fn get_u32(&self, store: &mut Store, name: &str) -> Result; - // Only exists natively in Circom2, hardcoded for Circom - fn get_version(&self, store: &mut Store) -> Result; -} - -pub trait Circom1 { - fn get_ptr_witness(&self, store: &mut Store, w: u32) -> Result; - fn get_fr_len(&self, store: &mut Store) -> Result; - fn get_signal_offset32( - &self, - store: &mut Store, - p_sig_offset: u32, - component: u32, - hash_msb: u32, - hash_lsb: u32, - ) -> Result<()>; - fn set_signal( - &self, - store: &mut Store, - c_idx: u32, - component: u32, - signal: u32, - p_val: u32, - ) -> Result<()>; - fn get_ptr_raw_prime(&self, store: &mut Store) -> Result; -} - -pub trait Circom2 { - fn get_field_num_len32(&self, store: &mut Store) -> Result; - fn get_raw_prime(&self, store: &mut Store) -> Result<()>; - fn read_shared_rw_memory(&self, store: &mut Store, i: u32) -> Result; - fn write_shared_rw_memory(&self, store: &mut Store, i: u32, v: u32) -> Result<()>; - fn set_input_signal(&self, store: &mut Store, hmsb: u32, hlsb: u32, pos: u32) -> Result<()>; - fn get_witness(&self, store: &mut Store, i: u32) -> Result<()>; - fn get_witness_size(&self, store: &mut Store) -> Result; -} - -impl Circom1 for Wasm { - fn get_fr_len(&self, store: &mut Store) -> Result { - self.get_u32(store, "getFrLen") - } - - fn get_ptr_raw_prime(&self, store: &mut Store) -> Result { - self.get_u32(store, "getPRawPrime") - } - - fn get_ptr_witness(&self, store: &mut Store, w: u32) -> Result { - let func = self.func("getPWitness"); - - let res = func.call(store, &[w.into()])?; - - Ok(res[0].unwrap_i32() as u32) - } - - fn get_signal_offset32( - &self, - store: &mut Store, - p_sig_offset: u32, - component: u32, - hash_msb: u32, - hash_lsb: u32, - ) -> Result<()> { - let func = self.func("getSignalOffset32"); - func.call( - store, - &[ - p_sig_offset.into(), - component.into(), - hash_msb.into(), - hash_lsb.into(), - ], - )?; - - Ok(()) - } - - fn set_signal( - &self, - store: &mut Store, - c_idx: u32, - component: u32, - signal: u32, - p_val: u32, - ) -> Result<()> { - let func = self.func("setSignal"); - func.call( - store, - &[c_idx.into(), component.into(), signal.into(), p_val.into()], - )?; - - Ok(()) - } -} - -#[cfg(feature = "circom-2")] -impl Circom2 for Wasm { - fn get_field_num_len32(&self, store: &mut Store) -> Result { +impl Wasm { + pub(crate) fn get_field_num_len32(&self, store: &mut Store) -> Result { self.get_u32(store, "getFieldNumLen32") } - fn get_raw_prime(&self, store: &mut Store) -> Result<()> { + pub(crate) fn get_raw_prime(&self, store: &mut Store) -> Result<()> { let func = self.func("getRawPrime"); func.call(store, &[])?; Ok(()) } - fn read_shared_rw_memory(&self, store: &mut Store, i: u32) -> Result { + pub(crate) fn read_shared_rw_memory(&self, store: &mut Store, i: u32) -> Result { let func = self.func("readSharedRWMemory"); let result = func.call(store, &[i.into()])?; Ok(result[0].unwrap_i32() as u32) } - fn write_shared_rw_memory(&self, store: &mut Store, i: u32, v: u32) -> Result<()> { + pub(crate) fn write_shared_rw_memory(&self, store: &mut Store, i: u32, v: u32) -> Result<()> { let func = self.func("writeSharedRWMemory"); func.call(store, &[i.into(), v.into()])?; Ok(()) } - fn set_input_signal(&self, store: &mut Store, hmsb: u32, hlsb: u32, pos: u32) -> Result<()> { + pub(crate) fn set_input_signal( + &self, + store: &mut Store, + hmsb: u32, + hlsb: u32, + pos: u32, + ) -> Result<()> { let func = self.func("setInputSignal"); func.call(store, &[hmsb.into(), hlsb.into(), pos.into()])?; Ok(()) } - fn get_witness(&self, store: &mut Store, i: u32) -> Result<()> { + pub(crate) fn get_witness(&self, store: &mut Store, i: u32) -> Result<()> { let func = self.func("getWitness"); func.call(store, &[i.into()])?; Ok(()) } - fn get_witness_size(&self, store: &mut Store) -> Result { + pub(crate) fn get_witness_size(&self, store: &mut Store) -> Result { self.get_u32(store, "getWitnessSize") } -} -impl CircomBase for Wasm { - fn init(&self, store: &mut Store, sanity_check: bool) -> Result<()> { + pub(crate) fn init(&self, store: &mut Store, sanity_check: bool) -> Result<()> { let func = self.func("init"); func.call(store, &[Value::I32(sanity_check as i32)])?; Ok(()) } - fn get_n_vars(&self, store: &mut Store) -> Result { - self.get_u32(store, "getNVars") - } - - // Default to version 1 if it isn't explicitly defined - fn get_version(&self, store: &mut Store) -> Result { - match self.exports.get_function("getVersion") { - Ok(func) => Ok(func.call(store, &[])?[0].unwrap_i32() as u32), - Err(_) => Ok(1), - } - } - fn get_u32(&self, store: &mut Store, name: &str) -> Result { let func = &self.func(name); let result = func.call(store, &[])?; diff --git a/src/witness/mod.rs b/src/witness/mod.rs index cbb8a8a..f396bd2 100644 --- a/src/witness/mod.rs +++ b/src/witness/mod.rs @@ -5,14 +5,8 @@ mod memory; pub(super) use memory::SafeMemory; mod circom; -pub(super) use circom::CircomBase; pub use circom::Wasm; -#[cfg(feature = "circom-2")] -pub(super) use circom::Circom2; - -pub(super) use circom::Circom1; - use fnv::FnvHasher; use std::hash::Hasher; diff --git a/src/witness/witness_calculator.rs b/src/witness/witness_calculator.rs index ce73edc..c517ad3 100644 --- a/src/witness/witness_calculator.rs +++ b/src/witness/witness_calculator.rs @@ -1,23 +1,17 @@ -use super::{fnv, CircomBase, SafeMemory, Wasm}; +use super::{fnv, SafeMemory, Wasm}; use color_eyre::Result; use num_bigint::BigInt; use num_traits::Zero; use wasmer::{imports, Function, Instance, Memory, MemoryType, Module, RuntimeError, Store}; use wasmer_wasix::WasiEnv; -#[cfg(feature = "circom-2")] use num::ToPrimitive; -use super::Circom1; -#[cfg(feature = "circom-2")] -use super::Circom2; - #[derive(Debug)] pub struct WitnessCalculator { pub instance: Wasm, pub memory: Option, pub n64: u32, - pub circom_version: u32, pub prime: BigInt, } @@ -27,7 +21,6 @@ pub struct WitnessCalculator { #[error("{0}")] struct ExitCode(u32); -#[cfg(feature = "circom-2")] fn from_array32(arr: Vec) -> BigInt { let mut res = BigInt::zero(); let radix = BigInt::from(0x100000000u64); @@ -37,7 +30,6 @@ fn from_array32(arr: Vec) -> BigInt { res } -#[cfg(feature = "circom-2")] fn to_array32(s: &BigInt, size: usize) -> Vec { let mut res = vec![0; size]; let mut rem = s.clone(); @@ -95,78 +87,24 @@ impl WitnessCalculator { Ok(wasm) } - pub fn new_from_wasm(store: &mut Store, wasm: Wasm) -> Result { - let version = wasm.get_version(store).unwrap_or(1); - // Circom 2 feature flag with version 2 - #[cfg(feature = "circom-2")] - fn new_circom2( - instance: Wasm, - store: &mut Store, - version: u32, - ) -> Result { - let n32 = instance.get_field_num_len32(store)?; - instance.get_raw_prime(store)?; - let mut arr = vec![0; n32 as usize]; - for i in 0..n32 { - let res = instance.read_shared_rw_memory(store, i)?; - arr[(n32 as usize) - (i as usize) - 1] = res; - } - let prime = from_array32(arr); - - let n64 = ((prime.bits() - 1) / 64 + 1) as u32; - - Ok(WitnessCalculator { - instance, - memory: None, - n64, - circom_version: version, - prime, - }) + pub fn new_from_wasm(store: &mut Store, instance: Wasm) -> Result { + let n32 = instance.get_field_num_len32(store)?; + instance.get_raw_prime(store)?; + let mut arr = vec![0; n32 as usize]; + for i in 0..n32 { + let res = instance.read_shared_rw_memory(store, i)?; + arr[(n32 as usize) - (i as usize) - 1] = res; } + let prime = from_array32(arr); - fn new_circom1( - instance: Wasm, - store: &mut Store, - version: u32, - ) -> Result { - // Fallback to Circom 1 behavior - let n32 = (instance.get_fr_len(store)? >> 2) - 2; - let mut safe_memory = - SafeMemory::new(instance.memory.clone(), n32 as usize, BigInt::zero()); - let ptr = instance.get_ptr_raw_prime(store)?; - let prime = safe_memory.read_big(store, ptr as usize, n32 as usize)?; - - let n64 = ((prime.bits() - 1) / 64 + 1) as u32; - safe_memory.prime = prime.clone(); - - Ok(WitnessCalculator { - instance, - memory: Some(safe_memory), - n64, - circom_version: version, - prime, - }) - } - - // Three possibilities: - // a) Circom 2 feature flag enabled, WASM runtime version 2 - // b) Circom 2 feature flag enabled, WASM runtime version 1 - // c) Circom 1 default behavior - // - // Once Circom 2 support is more stable, feature flag can be removed + let n64 = ((prime.bits() - 1) / 64 + 1) as u32; - cfg_if::cfg_if! { - if #[cfg(feature = "circom-2")] { - match version { - 2 => new_circom2(wasm, store, version), - 1 => new_circom1(wasm, store, version), - - _ => panic!("Unknown Circom version") - } - } else { - new_circom1(instance, memory, version) - } - } + Ok(WitnessCalculator { + instance, + memory: None, + n64, + prime, + }) } pub fn calculate_witness)>>( @@ -177,72 +115,9 @@ impl WitnessCalculator { ) -> Result> { self.instance.init(store, sanity_check)?; - cfg_if::cfg_if! { - if #[cfg(feature = "circom-2")] { - match self.circom_version { - 2 => self.calculate_witness_circom2(store, inputs), - 1 => self.calculate_witness_circom1(store, inputs), - _ => panic!("Unknown Circom version") - } - } else { - self.calculate_witness_circom1(inputs, sanity_check) - } - } + self.calculate_witness_circom2(store, inputs) } - // Circom 1 default behavior - fn calculate_witness_circom1)>>( - &mut self, - store: &mut Store, - inputs: I, - ) -> Result> { - let old_mem_free_pos = self.memory.as_ref().unwrap().free_pos(store)?; - let p_sig_offset = self.memory.as_mut().unwrap().alloc_u32(store)?; - let p_fr = self.memory.as_mut().unwrap().alloc_fr(store)?; - - // allocate the inputs - for (name, values) in inputs.into_iter() { - let (msb, lsb) = fnv(&name); - - self.instance - .get_signal_offset32(store, p_sig_offset, 0, msb, lsb)?; - - let sig_offset = self - .memory - .as_ref() - .unwrap() - .read_u32(store, p_sig_offset as usize) - .unwrap() as usize; - - for (i, value) in values.into_iter().enumerate() { - self.memory - .as_mut() - .unwrap() - .write_fr(store, p_fr as usize, &value)?; - self.instance - .set_signal(store, 0, 0, (sig_offset + i) as u32, p_fr)?; - } - } - - let mut w = Vec::new(); - - let n_vars = self.instance.get_n_vars(store)?; - for i in 0..n_vars { - let ptr = self.instance.get_ptr_witness(store, i)? as usize; - let el = self.memory.as_ref().unwrap().read_fr(store, ptr)?; - w.push(el); - } - - self.memory - .as_mut() - .unwrap() - .set_free_pos(store, old_mem_free_pos)?; - - Ok(w) - } - - // Circom 2 feature flag with version 2 - #[cfg(feature = "circom-2")] fn calculate_witness_circom2)>>( &mut self, store: &mut Store, @@ -483,10 +358,6 @@ mod tests { wtns.prime.to_str_radix(16), "30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001".to_lowercase() ); - assert_eq!( - { wtns.instance.get_n_vars(&mut store).unwrap() }, - case.n_vars - ); assert_eq!({ wtns.n64 }, case.n64); let inputs_str = std::fs::read_to_string(case.inputs_path).unwrap();