Browse Source

Cleanly separate Circom1 and Circom2 traits (#60)

pull/4/head
Martin Allen 4 months ago
committed by GitHub
parent
commit
4d99060fce
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
3 changed files with 67 additions and 80 deletions
  1. +41
    -46
      src/witness/circom.rs
  2. +1
    -1
      src/witness/mod.rs
  3. +25
    -33
      src/witness/witness_calculator.rs

+ 41
- 46
src/witness/circom.rs

@ -7,9 +7,15 @@ pub struct Wasm(Instance);
pub trait CircomBase { pub trait CircomBase {
fn init(&self, sanity_check: bool) -> Result<()>; fn init(&self, sanity_check: bool) -> Result<()>;
fn func(&self, name: &str) -> &Function; fn func(&self, name: &str) -> &Function;
fn get_ptr_witness_buffer(&self) -> Result<u32>;
fn get_ptr_witness(&self, w: u32) -> Result<u32>;
fn get_n_vars(&self) -> Result<u32>; fn get_n_vars(&self) -> Result<u32>;
fn get_u32(&self, name: &str) -> Result<u32>;
// Only exists natively in Circom2, hardcoded for Circom
fn get_version(&self) -> Result<u32>;
}
pub trait Circom1 {
fn get_ptr_witness(&self, w: u32) -> Result<u32>;
fn get_fr_len(&self) -> Result<u32>;
fn get_signal_offset32( fn get_signal_offset32(
&self, &self,
p_sig_offset: u32, p_sig_offset: u32,
@ -18,13 +24,6 @@ pub trait CircomBase {
hash_lsb: u32, hash_lsb: u32,
) -> Result<()>; ) -> Result<()>;
fn set_signal(&self, c_idx: u32, component: u32, signal: u32, p_val: u32) -> Result<()>; fn set_signal(&self, c_idx: u32, component: u32, signal: u32, p_val: u32) -> Result<()>;
fn get_u32(&self, name: &str) -> Result<u32>;
// Only exists natively in Circom2, hardcoded for Circom
fn get_version(&self) -> Result<u32>;
}
pub trait Circom {
fn get_fr_len(&self) -> Result<u32>;
fn get_ptr_raw_prime(&self) -> Result<u32>; fn get_ptr_raw_prime(&self) -> Result<u32>;
} }
@ -38,7 +37,7 @@ pub trait Circom2 {
fn get_witness_size(&self) -> Result<u32>; fn get_witness_size(&self) -> Result<u32>;
} }
impl Circom for Wasm {
impl Circom1 for Wasm {
fn get_fr_len(&self) -> Result<u32> { fn get_fr_len(&self) -> Result<u32> {
self.get_u32("getFrLen") self.get_u32("getFrLen")
} }
@ -46,6 +45,38 @@ impl Circom for Wasm {
fn get_ptr_raw_prime(&self) -> Result<u32> { fn get_ptr_raw_prime(&self) -> Result<u32> {
self.get_u32("getPRawPrime") self.get_u32("getPRawPrime")
} }
fn get_ptr_witness(&self, w: u32) -> Result<u32> {
let func = self.func("getPWitness");
let res = func.call(&[w.into()])?;
Ok(res[0].unwrap_i32() as u32)
}
fn get_signal_offset32(
&self,
p_sig_offset: u32,
component: u32,
hash_msb: u32,
hash_lsb: u32,
) -> Result<()> {
let func = self.func("getSignalOffset32");
func.call(&[
p_sig_offset.into(),
component.into(),
hash_msb.into(),
hash_lsb.into(),
])?;
Ok(())
}
fn set_signal(&self, c_idx: u32, component: u32, signal: u32, p_val: u32) -> Result<()> {
let func = self.func("setSignal");
func.call(&[c_idx.into(), component.into(), signal.into(), p_val.into()])?;
Ok(())
}
} }
#[cfg(feature = "circom-2")] #[cfg(feature = "circom-2")]
@ -96,46 +127,10 @@ impl CircomBase for Wasm {
Ok(()) Ok(())
} }
fn get_ptr_witness_buffer(&self) -> Result<u32> {
self.get_u32("getWitnessBuffer")
}
fn get_ptr_witness(&self, w: u32) -> Result<u32> {
let func = self.func("getPWitness");
let res = func.call(&[w.into()])?;
Ok(res[0].unwrap_i32() as u32)
}
fn get_n_vars(&self) -> Result<u32> { fn get_n_vars(&self) -> Result<u32> {
self.get_u32("getNVars") self.get_u32("getNVars")
} }
fn get_signal_offset32(
&self,
p_sig_offset: u32,
component: u32,
hash_msb: u32,
hash_lsb: u32,
) -> Result<()> {
let func = self.func("getSignalOffset32");
func.call(&[
p_sig_offset.into(),
component.into(),
hash_msb.into(),
hash_lsb.into(),
])?;
Ok(())
}
fn set_signal(&self, c_idx: u32, component: u32, signal: u32, p_val: u32) -> Result<()> {
let func = self.func("setSignal");
func.call(&[c_idx.into(), component.into(), signal.into(), p_val.into()])?;
Ok(())
}
// Default to version 1 if it isn't explicitly defined // Default to version 1 if it isn't explicitly defined
fn get_version(&self) -> Result<u32> { fn get_version(&self) -> Result<u32> {
match self.0.exports.get_function("getVersion") { match self.0.exports.get_function("getVersion") {

+ 1
- 1
src/witness/mod.rs

@ -10,7 +10,7 @@ pub(super) use circom::{CircomBase, Wasm};
#[cfg(feature = "circom-2")] #[cfg(feature = "circom-2")]
pub(super) use circom::Circom2; pub(super) use circom::Circom2;
pub(super) use circom::Circom;
pub(super) use circom::Circom1;
use fnv::FnvHasher; use fnv::FnvHasher;
use std::hash::Hasher; use std::hash::Hasher;

+ 25
- 33
src/witness/witness_calculator.rs

@ -2,23 +2,22 @@ use super::{fnv, CircomBase, SafeMemory, Wasm};
use color_eyre::Result; use color_eyre::Result;
use num_bigint::BigInt; use num_bigint::BigInt;
use num_traits::Zero; use num_traits::Zero;
use std::cell::Cell;
use wasmer::{imports, Function, Instance, Memory, MemoryType, Module, RuntimeError, Store}; use wasmer::{imports, Function, Instance, Memory, MemoryType, Module, RuntimeError, Store};
#[cfg(feature = "circom-2")] #[cfg(feature = "circom-2")]
use num::ToPrimitive; use num::ToPrimitive;
use super::Circom1;
#[cfg(feature = "circom-2")] #[cfg(feature = "circom-2")]
use super::Circom2; use super::Circom2;
use super::Circom;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct WitnessCalculator { pub struct WitnessCalculator {
pub instance: Wasm, pub instance: Wasm,
pub memory: SafeMemory,
pub memory: Option<SafeMemory>,
pub n64: u32, pub n64: u32,
pub circom_version: u32, pub circom_version: u32,
pub prime: BigInt,
} }
// Error type to signal end of execution. // Error type to signal end of execution.
@ -92,9 +91,8 @@ impl WitnessCalculator {
// Circom 2 feature flag with version 2 // Circom 2 feature flag with version 2
#[cfg(feature = "circom-2")] #[cfg(feature = "circom-2")]
fn new_circom2(instance: Wasm, memory: Memory, version: u32) -> Result<WitnessCalculator> {
fn new_circom2(instance: Wasm, version: u32) -> Result<WitnessCalculator> {
let n32 = instance.get_field_num_len32()?; let n32 = instance.get_field_num_len32()?;
let mut safe_memory = SafeMemory::new(memory, n32 as usize, BigInt::zero());
instance.get_raw_prime()?; instance.get_raw_prime()?;
let mut arr = vec![0; n32 as usize]; let mut arr = vec![0; n32 as usize];
for i in 0..n32 { for i in 0..n32 {
@ -104,13 +102,13 @@ impl WitnessCalculator {
let prime = from_array32(arr); let prime = from_array32(arr);
let n64 = ((prime.bits() - 1) / 64 + 1) as u32; let n64 = ((prime.bits() - 1) / 64 + 1) as u32;
safe_memory.prime = prime;
Ok(WitnessCalculator { Ok(WitnessCalculator {
instance, instance,
memory: safe_memory,
memory: None,
n64, n64,
circom_version: version, circom_version: version,
prime,
}) })
} }
@ -122,13 +120,14 @@ impl WitnessCalculator {
let prime = safe_memory.read_big(ptr as usize, n32 as usize)?; let prime = safe_memory.read_big(ptr as usize, n32 as usize)?;
let n64 = ((prime.bits() - 1) / 64 + 1) as u32; let n64 = ((prime.bits() - 1) / 64 + 1) as u32;
safe_memory.prime = prime;
safe_memory.prime = prime.clone();
Ok(WitnessCalculator { Ok(WitnessCalculator {
instance, instance,
memory: safe_memory,
memory: Some(safe_memory),
n64, n64,
circom_version: version, circom_version: version,
prime,
}) })
} }
@ -142,7 +141,7 @@ impl WitnessCalculator {
cfg_if::cfg_if! { cfg_if::cfg_if! {
if #[cfg(feature = "circom-2")] { if #[cfg(feature = "circom-2")] {
match version { match version {
2 => new_circom2(instance, memory, version),
2 => new_circom2(instance, version),
1 => new_circom1(instance, memory, version), 1 => new_circom1(instance, memory, version),
_ => panic!("Unknown Circom version") _ => panic!("Unknown Circom version")
} }
@ -180,9 +179,9 @@ impl WitnessCalculator {
) -> Result<Vec<BigInt>> { ) -> Result<Vec<BigInt>> {
self.instance.init(sanity_check)?; self.instance.init(sanity_check)?;
let old_mem_free_pos = self.memory.free_pos();
let p_sig_offset = self.memory.alloc_u32();
let p_fr = self.memory.alloc_fr();
let old_mem_free_pos = self.memory.as_ref().unwrap().free_pos();
let p_sig_offset = self.memory.as_mut().unwrap().alloc_u32();
let p_fr = self.memory.as_mut().unwrap().alloc_fr();
// allocate the inputs // allocate the inputs
for (name, values) in inputs.into_iter() { for (name, values) in inputs.into_iter() {
@ -191,10 +190,17 @@ impl WitnessCalculator {
self.instance self.instance
.get_signal_offset32(p_sig_offset, 0, msb, lsb)?; .get_signal_offset32(p_sig_offset, 0, msb, lsb)?;
let sig_offset = self.memory.read_u32(p_sig_offset as usize) as usize;
let sig_offset = self
.memory
.as_ref()
.unwrap()
.read_u32(p_sig_offset as usize) as usize;
for (i, value) in values.into_iter().enumerate() { for (i, value) in values.into_iter().enumerate() {
self.memory.write_fr(p_fr as usize, &value)?;
self.memory
.as_mut()
.unwrap()
.write_fr(p_fr as usize, &value)?;
self.instance self.instance
.set_signal(0, 0, (sig_offset + i) as u32, p_fr)?; .set_signal(0, 0, (sig_offset + i) as u32, p_fr)?;
} }
@ -205,11 +211,11 @@ impl WitnessCalculator {
let n_vars = self.instance.get_n_vars()?; let n_vars = self.instance.get_n_vars()?;
for i in 0..n_vars { for i in 0..n_vars {
let ptr = self.instance.get_ptr_witness(i)? as usize; let ptr = self.instance.get_ptr_witness(i)? as usize;
let el = self.memory.read_fr(ptr)?;
let el = self.memory.as_ref().unwrap().read_fr(ptr)?;
w.push(el); w.push(el);
} }
self.memory.set_free_pos(old_mem_free_pos);
self.memory.as_mut().unwrap().set_free_pos(old_mem_free_pos);
Ok(w) Ok(w)
} }
@ -283,20 +289,6 @@ impl WitnessCalculator {
Ok(witness) Ok(witness)
} }
pub fn get_witness_buffer(&self) -> Result<Vec<u8>> {
let ptr = self.instance.get_ptr_witness_buffer()? as usize;
let view = self.memory.memory.view::<u8>();
let len = self.instance.get_n_vars()? * self.n64 * 8;
let arr = view[ptr..ptr + len as usize]
.iter()
.map(Cell::get)
.collect::<Vec<_>>();
Ok(arr)
}
} }
// callback hooks for debugging // callback hooks for debugging
@ -463,7 +455,7 @@ mod tests {
fn run_test(case: TestCase) { fn run_test(case: TestCase) {
let mut wtns = WitnessCalculator::new(case.circuit_path).unwrap(); let mut wtns = WitnessCalculator::new(case.circuit_path).unwrap();
assert_eq!( assert_eq!(
wtns.memory.prime.to_str_radix(16),
wtns.prime.to_str_radix(16),
"30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001".to_lowercase() "30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001".to_lowercase()
); );
assert_eq!({ wtns.instance.get_n_vars().unwrap() }, case.n_vars); assert_eq!({ wtns.instance.get_n_vars().unwrap() }, case.n_vars);

Loading…
Cancel
Save