From 4d99060fce7817c56300e2dce16ec54a87ad9f66 Mon Sep 17 00:00:00 2001 From: Martin Allen <31280145+martyall@users.noreply.github.com> Date: Wed, 3 Jul 2024 19:04:42 -0700 Subject: [PATCH] Cleanly separate Circom1 and Circom2 traits (#60) --- src/witness/circom.rs | 87 +++++++++++++++---------------- src/witness/mod.rs | 2 +- src/witness/witness_calculator.rs | 58 +++++++++------------ 3 files changed, 67 insertions(+), 80 deletions(-) diff --git a/src/witness/circom.rs b/src/witness/circom.rs index f9f5aed..41e9db8 100644 --- a/src/witness/circom.rs +++ b/src/witness/circom.rs @@ -7,9 +7,15 @@ pub struct Wasm(Instance); pub trait CircomBase { fn init(&self, sanity_check: bool) -> Result<()>; fn func(&self, name: &str) -> &Function; - fn get_ptr_witness_buffer(&self) -> Result; - fn get_ptr_witness(&self, w: u32) -> Result; fn get_n_vars(&self) -> Result; + fn get_u32(&self, name: &str) -> Result; + // Only exists natively in Circom2, hardcoded for Circom + fn get_version(&self) -> Result; +} + +pub trait Circom1 { + fn get_ptr_witness(&self, w: u32) -> Result; + fn get_fr_len(&self) -> Result; fn get_signal_offset32( &self, p_sig_offset: u32, @@ -18,13 +24,6 @@ pub trait CircomBase { hash_lsb: u32, ) -> Result<()>; fn set_signal(&self, c_idx: u32, component: u32, signal: u32, p_val: u32) -> Result<()>; - fn get_u32(&self, name: &str) -> Result; - // Only exists natively in Circom2, hardcoded for Circom - fn get_version(&self) -> Result; -} - -pub trait Circom { - fn get_fr_len(&self) -> Result; fn get_ptr_raw_prime(&self) -> Result; } @@ -38,7 +37,7 @@ pub trait Circom2 { fn get_witness_size(&self) -> Result; } -impl Circom for Wasm { +impl Circom1 for Wasm { fn get_fr_len(&self) -> Result { self.get_u32("getFrLen") } @@ -46,6 +45,38 @@ impl Circom for Wasm { fn get_ptr_raw_prime(&self) -> Result { self.get_u32("getPRawPrime") } + + fn get_ptr_witness(&self, w: u32) -> Result { + let func = self.func("getPWitness"); + let res = func.call(&[w.into()])?; + + Ok(res[0].unwrap_i32() as u32) + } + + fn get_signal_offset32( + &self, + p_sig_offset: u32, + component: u32, + hash_msb: u32, + hash_lsb: u32, + ) -> Result<()> { + let func = self.func("getSignalOffset32"); + func.call(&[ + p_sig_offset.into(), + component.into(), + hash_msb.into(), + hash_lsb.into(), + ])?; + + Ok(()) + } + + fn set_signal(&self, c_idx: u32, component: u32, signal: u32, p_val: u32) -> Result<()> { + let func = self.func("setSignal"); + func.call(&[c_idx.into(), component.into(), signal.into(), p_val.into()])?; + + Ok(()) + } } #[cfg(feature = "circom-2")] @@ -96,46 +127,10 @@ impl CircomBase for Wasm { Ok(()) } - fn get_ptr_witness_buffer(&self) -> Result { - self.get_u32("getWitnessBuffer") - } - - fn get_ptr_witness(&self, w: u32) -> Result { - let func = self.func("getPWitness"); - let res = func.call(&[w.into()])?; - - Ok(res[0].unwrap_i32() as u32) - } - fn get_n_vars(&self) -> Result { self.get_u32("getNVars") } - fn get_signal_offset32( - &self, - p_sig_offset: u32, - component: u32, - hash_msb: u32, - hash_lsb: u32, - ) -> Result<()> { - let func = self.func("getSignalOffset32"); - func.call(&[ - p_sig_offset.into(), - component.into(), - hash_msb.into(), - hash_lsb.into(), - ])?; - - Ok(()) - } - - fn set_signal(&self, c_idx: u32, component: u32, signal: u32, p_val: u32) -> Result<()> { - let func = self.func("setSignal"); - func.call(&[c_idx.into(), component.into(), signal.into(), p_val.into()])?; - - Ok(()) - } - // Default to version 1 if it isn't explicitly defined fn get_version(&self) -> Result { match self.0.exports.get_function("getVersion") { diff --git a/src/witness/mod.rs b/src/witness/mod.rs index 51708aa..e8304b7 100644 --- a/src/witness/mod.rs +++ b/src/witness/mod.rs @@ -10,7 +10,7 @@ pub(super) use circom::{CircomBase, Wasm}; #[cfg(feature = "circom-2")] pub(super) use circom::Circom2; -pub(super) use circom::Circom; +pub(super) use circom::Circom1; use fnv::FnvHasher; use std::hash::Hasher; diff --git a/src/witness/witness_calculator.rs b/src/witness/witness_calculator.rs index 49582b9..a5cbe54 100644 --- a/src/witness/witness_calculator.rs +++ b/src/witness/witness_calculator.rs @@ -2,23 +2,22 @@ use super::{fnv, CircomBase, SafeMemory, Wasm}; use color_eyre::Result; use num_bigint::BigInt; use num_traits::Zero; -use std::cell::Cell; use wasmer::{imports, Function, Instance, Memory, MemoryType, Module, RuntimeError, Store}; #[cfg(feature = "circom-2")] use num::ToPrimitive; +use super::Circom1; #[cfg(feature = "circom-2")] use super::Circom2; -use super::Circom; - #[derive(Clone, Debug)] pub struct WitnessCalculator { pub instance: Wasm, - pub memory: SafeMemory, + pub memory: Option, pub n64: u32, pub circom_version: u32, + pub prime: BigInt, } // Error type to signal end of execution. @@ -92,9 +91,8 @@ impl WitnessCalculator { // Circom 2 feature flag with version 2 #[cfg(feature = "circom-2")] - fn new_circom2(instance: Wasm, memory: Memory, version: u32) -> Result { + fn new_circom2(instance: Wasm, version: u32) -> Result { let n32 = instance.get_field_num_len32()?; - let mut safe_memory = SafeMemory::new(memory, n32 as usize, BigInt::zero()); instance.get_raw_prime()?; let mut arr = vec![0; n32 as usize]; for i in 0..n32 { @@ -104,13 +102,13 @@ impl WitnessCalculator { let prime = from_array32(arr); let n64 = ((prime.bits() - 1) / 64 + 1) as u32; - safe_memory.prime = prime; Ok(WitnessCalculator { instance, - memory: safe_memory, + memory: None, n64, circom_version: version, + prime, }) } @@ -122,13 +120,14 @@ impl WitnessCalculator { let prime = safe_memory.read_big(ptr as usize, n32 as usize)?; let n64 = ((prime.bits() - 1) / 64 + 1) as u32; - safe_memory.prime = prime; + safe_memory.prime = prime.clone(); Ok(WitnessCalculator { instance, - memory: safe_memory, + memory: Some(safe_memory), n64, circom_version: version, + prime, }) } @@ -142,7 +141,7 @@ impl WitnessCalculator { cfg_if::cfg_if! { if #[cfg(feature = "circom-2")] { match version { - 2 => new_circom2(instance, memory, version), + 2 => new_circom2(instance, version), 1 => new_circom1(instance, memory, version), _ => panic!("Unknown Circom version") } @@ -180,9 +179,9 @@ impl WitnessCalculator { ) -> Result> { self.instance.init(sanity_check)?; - let old_mem_free_pos = self.memory.free_pos(); - let p_sig_offset = self.memory.alloc_u32(); - let p_fr = self.memory.alloc_fr(); + let old_mem_free_pos = self.memory.as_ref().unwrap().free_pos(); + let p_sig_offset = self.memory.as_mut().unwrap().alloc_u32(); + let p_fr = self.memory.as_mut().unwrap().alloc_fr(); // allocate the inputs for (name, values) in inputs.into_iter() { @@ -191,10 +190,17 @@ impl WitnessCalculator { self.instance .get_signal_offset32(p_sig_offset, 0, msb, lsb)?; - let sig_offset = self.memory.read_u32(p_sig_offset as usize) as usize; + let sig_offset = self + .memory + .as_ref() + .unwrap() + .read_u32(p_sig_offset as usize) as usize; for (i, value) in values.into_iter().enumerate() { - self.memory.write_fr(p_fr as usize, &value)?; + self.memory + .as_mut() + .unwrap() + .write_fr(p_fr as usize, &value)?; self.instance .set_signal(0, 0, (sig_offset + i) as u32, p_fr)?; } @@ -205,11 +211,11 @@ impl WitnessCalculator { let n_vars = self.instance.get_n_vars()?; for i in 0..n_vars { let ptr = self.instance.get_ptr_witness(i)? as usize; - let el = self.memory.read_fr(ptr)?; + let el = self.memory.as_ref().unwrap().read_fr(ptr)?; w.push(el); } - self.memory.set_free_pos(old_mem_free_pos); + self.memory.as_mut().unwrap().set_free_pos(old_mem_free_pos); Ok(w) } @@ -283,20 +289,6 @@ impl WitnessCalculator { Ok(witness) } - - pub fn get_witness_buffer(&self) -> Result> { - let ptr = self.instance.get_ptr_witness_buffer()? as usize; - - let view = self.memory.memory.view::(); - - let len = self.instance.get_n_vars()? * self.n64 * 8; - let arr = view[ptr..ptr + len as usize] - .iter() - .map(Cell::get) - .collect::>(); - - Ok(arr) - } } // callback hooks for debugging @@ -463,7 +455,7 @@ mod tests { fn run_test(case: TestCase) { let mut wtns = WitnessCalculator::new(case.circuit_path).unwrap(); assert_eq!( - wtns.memory.prime.to_str_radix(16), + wtns.prime.to_str_radix(16), "30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001".to_lowercase() ); assert_eq!({ wtns.instance.get_n_vars().unwrap() }, case.n_vars);