Browse Source

Ensure Circom 1 tests pass with experimental Circom 2 support (#18)

* All tests pass under circom-2 feature flag

- Check for version in WASM, default to version 1
- Include Circom1 when Circom 2 feature flag is enabled

Currently a lot of code duplication. Once Circom-2 is more stable and
proven to work in the wild, feature flag can be removed.

* Separate Circom 1 and Circom2 witness calculation

* Cleanup

* WitnessCalculator helpers for Circom 1 and 2

Also make helper fn private

* Move comment

* Fix expression return

* cargo fmt

* Add cargo test circom-2 to ci
pull/3/head
oskarth 2 years ago
committed by GitHub
parent
commit
1a383b6260
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 196 additions and 73 deletions
  1. +5
    -0
      .github/workflows/ci.yml
  2. +47
    -0
      Cargo.lock
  3. +10
    -6
      src/witness/circom.rs
  4. +0
    -1
      src/witness/mod.rs
  5. +134
    -66
      src/witness/witness_calculator.rs

+ 5
- 0
.github/workflows/ci.yml

@ -44,6 +44,11 @@ jobs:
export PATH=$HOME/bin:$PATH export PATH=$HOME/bin:$PATH
cargo test cargo test
- name: cargo test circom 2 feature flag
run: |
export PATH=$HOME/bin:$PATH
cargo test --features circom-2
lint: lint:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:

+ 47
- 0
Cargo.lock

@ -97,6 +97,7 @@ dependencies = [
"fnv", "fnv",
"hex", "hex",
"hex-literal", "hex-literal",
"num",
"num-bigint", "num-bigint",
"num-traits", "num-traits",
"serde_json", "serde_json",
@ -1950,6 +1951,20 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "num"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43db66d1170d347f9a065114077f7dccb00c1b9478c89384490a3425279a4606"
dependencies = [
"num-bigint",
"num-complex",
"num-integer",
"num-iter",
"num-rational",
"num-traits",
]
[[package]] [[package]]
name = "num-bigint" name = "num-bigint"
version = "0.4.0" version = "0.4.0"
@ -1962,6 +1977,15 @@ dependencies = [
"rand", "rand",
] ]
[[package]]
name = "num-complex"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26873667bbbb7c5182d4a37c1add32cdf09f841af72da53318fdb81543c15085"
dependencies = [
"num-traits",
]
[[package]] [[package]]
name = "num-integer" name = "num-integer"
version = "0.1.44" version = "0.1.44"
@ -1972,6 +1996,29 @@ dependencies = [
"num-traits", "num-traits",
] ]
[[package]]
name = "num-iter"
version = "0.1.42"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b2021c8337a54d21aca0d59a92577a029af9431cb59b909b03252b9c164fad59"
dependencies = [
"autocfg",
"num-integer",
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d41702bd167c2df5520b384281bc111a4b5efcf7fbc4c9c222c815b07e0a6a6a"
dependencies = [
"autocfg",
"num-bigint",
"num-integer",
"num-traits",
]
[[package]] [[package]]
name = "num-traits" name = "num-traits"
version = "0.2.14" version = "0.2.14"

+ 10
- 6
src/witness/circom.rs

@ -19,6 +19,8 @@ pub trait CircomBase {
) -> Result<()>; ) -> Result<()>;
fn set_signal(&self, c_idx: i32, component: i32, signal: i32, p_val: i32) -> Result<()>; fn set_signal(&self, c_idx: i32, component: i32, signal: i32, p_val: i32) -> Result<()>;
fn get_i32(&self, name: &str) -> Result<i32>; fn get_i32(&self, name: &str) -> Result<i32>;
// Only exists natively in Circom2, hardcoded for Circom
fn get_version(&self) -> Result<i32>;
} }
pub trait Circom { pub trait Circom {
@ -27,7 +29,6 @@ pub trait Circom {
} }
pub trait Circom2 { pub trait Circom2 {
fn get_version(&self) -> Result<i32>;
fn get_field_num_len32(&self) -> Result<i32>; fn get_field_num_len32(&self) -> Result<i32>;
fn get_raw_prime(&self) -> Result<()>; fn get_raw_prime(&self) -> Result<()>;
fn read_shared_rw_memory(&self, i: i32) -> Result<i32>; fn read_shared_rw_memory(&self, i: i32) -> Result<i32>;
@ -37,7 +38,6 @@ pub trait Circom2 {
fn get_witness_size(&self) -> Result<i32>; fn get_witness_size(&self) -> Result<i32>;
} }
#[cfg(not(feature = "circom-2"))]
impl Circom for Wasm { impl Circom for Wasm {
fn get_fr_len(&self) -> Result<i32> { fn get_fr_len(&self) -> Result<i32> {
self.get_i32("getFrLen") self.get_i32("getFrLen")
@ -50,10 +50,6 @@ impl Circom for Wasm {
#[cfg(feature = "circom-2")] #[cfg(feature = "circom-2")]
impl Circom2 for Wasm { impl Circom2 for Wasm {
fn get_version(&self) -> Result<i32> {
self.get_i32("getVersion")
}
fn get_field_num_len32(&self) -> Result<i32> { fn get_field_num_len32(&self) -> Result<i32> {
self.get_i32("getFieldNumLen32") self.get_i32("getFieldNumLen32")
} }
@ -142,6 +138,14 @@ impl CircomBase for Wasm {
Ok(()) Ok(())
} }
// Default to version 1 if it isn't explicitly defined
fn get_version(&self) -> Result<i32> {
match self.0.exports.get_function("getVersion") {
Ok(func) => Ok(func.call(&[])?[0].unwrap_i32()),
Err(_) => Ok(1),
}
}
fn get_i32(&self, name: &str) -> Result<i32> { fn get_i32(&self, name: &str) -> Result<i32> {
let func = self.func(name); let func = self.func(name);
let result = func.call(&[])?; let result = func.call(&[])?;

+ 0
- 1
src/witness/mod.rs

@ -10,7 +10,6 @@ pub(super) use circom::{CircomBase, Wasm};
#[cfg(feature = "circom-2")] #[cfg(feature = "circom-2")]
pub(super) use circom::Circom2; pub(super) use circom::Circom2;
#[cfg(not(feature = "circom-2"))]
pub(super) use circom::Circom; pub(super) use circom::Circom;
use fnv::FnvHasher; use fnv::FnvHasher;

+ 134
- 66
src/witness/witness_calculator.rs

@ -11,7 +11,6 @@ use num::ToPrimitive;
#[cfg(feature = "circom-2")] #[cfg(feature = "circom-2")]
use super::Circom2; use super::Circom2;
#[cfg(not(feature = "circom-2"))]
use super::Circom; use super::Circom;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -19,6 +18,7 @@ pub struct WitnessCalculator {
pub instance: Wasm, pub instance: Wasm,
pub memory: SafeMemory, pub memory: SafeMemory,
pub n64: i32, pub n64: i32,
pub circom_version: i32,
} }
// Error type to signal end of execution. // Error type to signal end of execution.
@ -77,36 +77,73 @@ impl WitnessCalculator {
}; };
let instance = Wasm::new(Instance::new(&module, &import_object)?); let instance = Wasm::new(Instance::new(&module, &import_object)?);
let version;
match instance.get_version() {
Ok(v) => version = v,
Err(_) => version = 1,
}
// Circom 2 feature flag with version 2
#[cfg(feature = "circom-2")]
fn new_circom2(instance: Wasm, memory: Memory, version: i32) -> Result<WitnessCalculator> {
let n32 = instance.get_field_num_len32()?;
let mut safe_memory = SafeMemory::new(memory, n32 as usize, BigInt::zero());
instance.get_raw_prime()?;
let mut arr = vec![0; n32 as usize];
for i in 0..n32 {
let res = instance.read_shared_rw_memory(i)?;
arr[(n32 as usize) - (i as usize) - 1] = res;
}
let prime = from_array32(arr);
let n64 = ((prime.bits() - 1) / 64 + 1) as i32;
safe_memory.prime = prime;
Ok(WitnessCalculator {
instance,
memory: safe_memory,
n64,
circom_version: version,
})
}
fn new_circom1(instance: Wasm, memory: Memory, version: i32) -> Result<WitnessCalculator> {
// Fallback to Circom 1 behavior
let n32 = (instance.get_fr_len()? >> 2) - 2;
let mut safe_memory = SafeMemory::new(memory, n32 as usize, BigInt::zero());
let ptr = instance.get_ptr_raw_prime()?;
let prime = safe_memory.read_big(ptr as usize, n32 as usize)?;
let n64 = ((prime.bits() - 1) / 64 + 1) as i32;
safe_memory.prime = prime;
Ok(WitnessCalculator {
instance,
memory: safe_memory,
n64,
circom_version: version,
})
}
// Three possibilities:
// a) Circom 2 feature flag enabled, WASM runtime version 2
// b) Circom 2 feature flag enabled, WASM runtime version 1
// c) Circom 1 default behavior
//
// Once Circom 2 support is more stable, feature flag can be removed
cfg_if::cfg_if! { cfg_if::cfg_if! {
if #[cfg(feature = "circom-2")] { if #[cfg(feature = "circom-2")] {
//let version = instance.get_version()?;
let n32 = instance.get_field_num_len32()?;
let mut safe_memory = SafeMemory::new(memory, n32 as usize, BigInt::zero());
instance.get_raw_prime()?;
let mut arr = vec![0; n32 as usize];
for i in 0..n32 {
let res = instance.read_shared_rw_memory(i)?;
arr[(n32 as usize) - (i as usize) - 1] = res;
match version {
2 => new_circom2(instance, memory, version),
1 => new_circom1(instance, memory, version),
_ => panic!("Unknown Circom version")
} }
let prime = from_array32(arr);
} else { } else {
// Fallback to Circom 1 behavior
//version = 1;
let n32 = (instance.get_fr_len()? >> 2) - 2;
let mut safe_memory = SafeMemory::new(memory, n32 as usize, BigInt::zero());
let ptr = instance.get_ptr_raw_prime()?;
let prime = safe_memory.read_big(ptr as usize, n32 as usize)?;
new_circom1(instance, memory, version)
} }
} }
let n64 = ((prime.bits() - 1) / 64 + 1) as i32;
safe_memory.prime = prime;
Ok(WitnessCalculator {
instance,
memory: safe_memory,
n64,
})
} }
pub fn calculate_witness<I: IntoIterator<Item = (String, Vec<BigInt>)>>( pub fn calculate_witness<I: IntoIterator<Item = (String, Vec<BigInt>)>>(
@ -118,66 +155,97 @@ impl WitnessCalculator {
cfg_if::cfg_if! { cfg_if::cfg_if! {
if #[cfg(feature = "circom-2")] { if #[cfg(feature = "circom-2")] {
let n32 = self.instance.get_field_num_len32()?;
match self.circom_version {
2 => self.calculate_witness_circom2(inputs, sanity_check),
1 => self.calculate_witness_circom1(inputs, sanity_check),
_ => panic!("Unknown Circom version")
}
} else { } else {
let old_mem_free_pos = self.memory.free_pos();
let p_sig_offset = self.memory.alloc_u32();
let p_fr = self.memory.alloc_fr();
self.calculate_witness_circom1(inputs, sanity_check)
} }
} }
}
// Circom 1 default behavior
fn calculate_witness_circom1<I: IntoIterator<Item = (String, Vec<BigInt>)>>(
&mut self,
inputs: I,
sanity_check: bool,
) -> Result<Vec<BigInt>> {
self.instance.init(sanity_check)?;
let old_mem_free_pos = self.memory.free_pos();
let p_sig_offset = self.memory.alloc_u32();
let p_fr = self.memory.alloc_fr();
// allocate the inputs // allocate the inputs
for (name, values) in inputs.into_iter() { for (name, values) in inputs.into_iter() {
let (msb, lsb) = fnv(&name); let (msb, lsb) = fnv(&name);
cfg_if::cfg_if! {
if #[cfg(feature = "circom-2")] {
for (i, value) in values.into_iter().enumerate() {
let f_arr = to_array32(&value, n32 as usize);
for j in 0..n32 {
self.instance.write_shared_rw_memory(j as i32, f_arr[(n32 as usize) - 1 - (j as usize)])?;
}
self.instance.set_input_signal(msb as i32, lsb as i32, i as i32)?;
}
} else {
self.instance
.get_signal_offset32(p_sig_offset, 0, msb, lsb)?;
self.instance
.get_signal_offset32(p_sig_offset, 0, msb, lsb)?;
let sig_offset = self.memory.read_u32(p_sig_offset as usize) as usize;
let sig_offset = self.memory.read_u32(p_sig_offset as usize) as usize;
for (i, value) in values.into_iter().enumerate() {
self.memory.write_fr(p_fr as usize, &value)?;
self.instance
.set_signal(0, 0, (sig_offset + i) as i32, p_fr as i32)?;
}
}
for (i, value) in values.into_iter().enumerate() {
self.memory.write_fr(p_fr as usize, &value)?;
self.instance
.set_signal(0, 0, (sig_offset + i) as i32, p_fr as i32)?;
} }
} }
let mut w = Vec::new(); let mut w = Vec::new();
cfg_if::cfg_if! {
if #[cfg(feature = "circom-2")] {
let witness_size = self.instance.get_witness_size()?;
for i in 0..witness_size {
self.instance.get_witness(i)?;
let mut arr = vec![0; n32 as usize];
for j in 0..n32 {
arr[(n32 as usize) - 1- (j as usize)] = self.instance.read_shared_rw_memory(j)?;
}
w.push(from_array32(arr));
}
let n_vars = self.instance.get_n_vars()?;
for i in 0..n_vars {
let ptr = self.instance.get_ptr_witness(i)? as usize;
let el = self.memory.read_fr(ptr)?;
w.push(el);
}
} else {
let n_vars = self.instance.get_n_vars()?;
for i in 0..n_vars {
let ptr = self.instance.get_ptr_witness(i)? as usize;
let el = self.memory.read_fr(ptr)?;
w.push(el);
self.memory.set_free_pos(old_mem_free_pos);
Ok(w)
}
// Circom 2 feature flag with version 2
#[cfg(feature = "circom-2")]
fn calculate_witness_circom2<I: IntoIterator<Item = (String, Vec<BigInt>)>>(
&mut self,
inputs: I,
sanity_check: bool,
) -> Result<Vec<BigInt>> {
self.instance.init(sanity_check)?;
let n32 = self.instance.get_field_num_len32()?;
// allocate the inputs
for (name, values) in inputs.into_iter() {
let (msb, lsb) = fnv(&name);
for (i, value) in values.into_iter().enumerate() {
let f_arr = to_array32(&value, n32 as usize);
for j in 0..n32 {
self.instance.write_shared_rw_memory(
j as i32,
f_arr[(n32 as usize) - 1 - (j as usize)],
)?;
} }
self.instance
.set_input_signal(msb as i32, lsb as i32, i as i32)?;
}
}
let mut w = Vec::new();
self.memory.set_free_pos(old_mem_free_pos);
let witness_size = self.instance.get_witness_size()?;
for i in 0..witness_size {
self.instance.get_witness(i)?;
let mut arr = vec![0; n32 as usize];
for j in 0..n32 {
arr[(n32 as usize) - 1 - (j as usize)] = self.instance.read_shared_rw_memory(j)?;
} }
w.push(from_array32(arr));
} }
Ok(w) Ok(w)

Loading…
Cancel
Save