Browse Source

Initial Circom 2 support (#10)

* Import circom-2 test vectors

* Add failing test under feature flag

* Add exceptionHandler

* Add showSharedRWMemory

* Add getFieldNumLen32 and disable getFrLen

* Add getVersion

Also print version, n32

* Add getRawPrime

- Disable getPtrRawPrime
- Write as conditional cfg code blocks

* Refactor cfg code blocks

* Add readSharedRWMemory and get prime from WASM mem

- Add fromArray32 convenience function

* WIP: Debug R1CSfile header

field_size in header is 1, not 32 as expected

Don't see anything recently changed here:
https://github.com/iden3/r1csfile/blob/master/src/r1csfile.js (used by snarkjs)

But this seems new: 0149dc0643/constraint_writers/src/r1cs_writer.rs

* Add CircomVersion struct to Wasm

* XXX: Enum test

* Trait version

* Move traits to Circom, CircomBase, Circom2

* Simplify Wasm struct and remove version

* Feature gate Circom1/Circom2 traits

* Use cfg_if for witness calculation

Make normal dependency

* Fix visibilty for both test paths

* Remove println

Can introduce tracing separately

* refactor

* Make clippy happy with imports, unused variables
pull/3/head
oskarth 2 years ago
committed by GitHub
parent
commit
64e0ee9546
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 175 additions and 27 deletions
  1. +3
    -1
      Cargo.toml
  2. +74
    -17
      src/witness/circom.rs
  3. +7
    -1
      src/witness/mod.rs
  4. +59
    -8
      src/witness/witness_calculator.rs
  5. BIN
      test-vectors/circom2_multiplier2.r1cs
  6. BIN
      test-vectors/circom2_multiplier2.wasm
  7. +32
    -0
      tests/groth16.rs

+ 3
- 1
Cargo.toml

@ -32,12 +32,13 @@ thiserror = "1.0.26"
color-eyre = "0.5" color-eyre = "0.5"
criterion = "0.3.4" criterion = "0.3.4"
cfg-if = "1.0"
[dev-dependencies] [dev-dependencies]
hex-literal = "0.2.1" hex-literal = "0.2.1"
tokio = { version = "1.7.1", features = ["macros"] } tokio = { version = "1.7.1", features = ["macros"] }
serde_json = "1.0.64" serde_json = "1.0.64"
ethers = { git = "https://github.com/gakonst/ethers-rs", features = ["abigen"] } ethers = { git = "https://github.com/gakonst/ethers-rs", features = ["abigen"] }
cfg-if = "1.0"
[[bench]] [[bench]]
name = "groth16" name = "groth16"
@ -45,3 +46,4 @@ harness = false
[features] [features]
bench-complex-all = [] bench-complex-all = []
circom-2 = []

+ 74
- 17
src/witness/circom.rs

@ -4,41 +4,92 @@ use wasmer::{Function, Instance, Value};
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Wasm(Instance); pub struct Wasm(Instance);
impl Wasm {
pub fn new(instance: Instance) -> Self {
Self(instance)
}
pub trait CircomBase {
fn init(&self, sanity_check: bool) -> Result<()>;
fn func(&self, name: &str) -> &Function;
fn get_ptr_witness_buffer(&self) -> Result<i32>;
fn get_ptr_witness(&self, w: i32) -> Result<i32>;
fn get_n_vars(&self) -> Result<i32>;
fn get_signal_offset32(
&self,
p_sig_offset: u32,
component: u32,
hash_msb: u32,
hash_lsb: u32,
) -> Result<()>;
fn set_signal(&self, c_idx: i32, component: i32, signal: i32, p_val: i32) -> Result<()>;
fn get_i32(&self, name: &str) -> Result<i32>;
}
pub fn init(&self, sanity_check: bool) -> Result<()> {
let func = self.func("init");
func.call(&[Value::I32(sanity_check as i32)])?;
Ok(())
}
pub trait Circom {
fn get_fr_len(&self) -> Result<i32>;
fn get_ptr_raw_prime(&self) -> Result<i32>;
}
pub trait Circom2 {
fn get_version(&self) -> Result<i32>;
fn get_field_num_len32(&self) -> Result<i32>;
fn get_raw_prime(&self) -> Result<()>;
fn read_shared_rw_memory(&self, i: i32) -> Result<i32>;
}
pub fn get_fr_len(&self) -> Result<i32> {
#[cfg(not(feature = "circom-2"))]
impl Circom for Wasm {
fn get_fr_len(&self) -> Result<i32> {
self.get_i32("getFrLen") self.get_i32("getFrLen")
} }
pub fn get_ptr_raw_prime(&self) -> Result<i32> {
fn get_ptr_raw_prime(&self) -> Result<i32> {
self.get_i32("getPRawPrime") self.get_i32("getPRawPrime")
} }
}
pub fn get_n_vars(&self) -> Result<i32> {
self.get_i32("getNVars")
#[cfg(feature = "circom-2")]
impl Circom2 for Wasm {
fn get_version(&self) -> Result<i32> {
self.get_i32("getVersion")
}
fn get_field_num_len32(&self) -> Result<i32> {
self.get_i32("getFieldNumLen32")
} }
pub fn get_ptr_witness_buffer(&self) -> Result<i32> {
fn get_raw_prime(&self) -> Result<()> {
let func = self.func("getRawPrime");
let _result = func.call(&[])?;
Ok(())
}
fn read_shared_rw_memory(&self, i: i32) -> Result<i32> {
let func = self.func("readSharedRWMemory");
let result = func.call(&[i.into()])?;
Ok(result[0].unwrap_i32())
}
}
impl CircomBase for Wasm {
fn init(&self, sanity_check: bool) -> Result<()> {
let func = self.func("init");
func.call(&[Value::I32(sanity_check as i32)])?;
Ok(())
}
fn get_ptr_witness_buffer(&self) -> Result<i32> {
self.get_i32("getWitnessBuffer") self.get_i32("getWitnessBuffer")
} }
pub fn get_ptr_witness(&self, w: i32) -> Result<i32> {
fn get_ptr_witness(&self, w: i32) -> Result<i32> {
let func = self.func("getPWitness"); let func = self.func("getPWitness");
let res = func.call(&[w.into()])?; let res = func.call(&[w.into()])?;
Ok(res[0].unwrap_i32()) Ok(res[0].unwrap_i32())
} }
pub fn get_signal_offset32(
fn get_n_vars(&self) -> Result<i32> {
self.get_i32("getNVars")
}
fn get_signal_offset32(
&self, &self,
p_sig_offset: u32, p_sig_offset: u32,
component: u32, component: u32,
@ -56,7 +107,7 @@ impl Wasm {
Ok(()) Ok(())
} }
pub fn set_signal(&self, c_idx: i32, component: i32, signal: i32, p_val: i32) -> Result<()> {
fn set_signal(&self, c_idx: i32, component: i32, signal: i32, p_val: i32) -> Result<()> {
let func = self.func("setSignal"); let func = self.func("setSignal");
func.call(&[c_idx.into(), component.into(), signal.into(), p_val.into()])?; func.call(&[c_idx.into(), component.into(), signal.into(), p_val.into()])?;
@ -76,3 +127,9 @@ impl Wasm {
.unwrap_or_else(|_| panic!("function {} not found", name)) .unwrap_or_else(|_| panic!("function {} not found", name))
} }
} }
impl Wasm {
pub fn new(instance: Instance) -> Self {
Self(instance)
}
}

+ 7
- 1
src/witness/mod.rs

@ -5,7 +5,13 @@ mod memory;
pub(super) use memory::SafeMemory; pub(super) use memory::SafeMemory;
mod circom; mod circom;
pub(super) use circom::Wasm;
pub(super) use circom::{CircomBase, Wasm};
#[cfg(feature = "circom-2")]
pub(super) use circom::Circom2;
#[cfg(not(feature = "circom-2"))]
pub(super) use circom::Circom;
use fnv::FnvHasher; use fnv::FnvHasher;
use std::hash::Hasher; use std::hash::Hasher;

+ 59
- 8
src/witness/witness_calculator.rs

@ -1,10 +1,15 @@
use super::{fnv, CircomBase, SafeMemory, Wasm};
use color_eyre::Result; use color_eyre::Result;
use num_bigint::BigInt; use num_bigint::BigInt;
use num_traits::Zero; use num_traits::Zero;
use std::cell::Cell; use std::cell::Cell;
use wasmer::{imports, Function, Instance, Memory, MemoryType, Module, RuntimeError, Store}; use wasmer::{imports, Function, Instance, Memory, MemoryType, Module, RuntimeError, Store};
use super::{fnv, SafeMemory, Wasm};
#[cfg(feature = "circom-2")]
use super::Circom2;
#[cfg(not(feature = "circom-2"))]
use super::Circom;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct WitnessCalculator { pub struct WitnessCalculator {
@ -19,6 +24,16 @@ pub struct WitnessCalculator {
#[error("{0}")] #[error("{0}")]
struct ExitCode(u32); struct ExitCode(u32);
#[cfg(feature = "circom-2")]
fn from_array32(arr: Vec<i32>) -> BigInt {
let mut res = BigInt::zero();
let radix = BigInt::from(0x100000000u64);
for &val in arr.iter() {
res = res * &radix + BigInt::from(val);
}
res
}
impl WitnessCalculator { impl WitnessCalculator {
pub fn new(path: impl AsRef<std::path::Path>) -> Result<Self> { pub fn new(path: impl AsRef<std::path::Path>) -> Result<Self> {
let store = Store::default(); let store = Store::default();
@ -38,22 +53,44 @@ impl WitnessCalculator {
"logFinishComponent" => runtime::log_component(&store), "logFinishComponent" => runtime::log_component(&store),
"logStartComponent" => runtime::log_component(&store), "logStartComponent" => runtime::log_component(&store),
"log" => runtime::log_component(&store), "log" => runtime::log_component(&store),
"exceptionHandler" => runtime::exception_handler(&store),
"showSharedRWMemory" => runtime::show_memory(&store),
} }
}; };
let instance = Wasm::new(Instance::new(&module, &import_object)?); let instance = Wasm::new(Instance::new(&module, &import_object)?);
let n32 = (instance.get_fr_len()? >> 2) - 2;
let mut memory = SafeMemory::new(memory, n32 as usize, BigInt::zero());
let n32;
let prime: BigInt;
let mut safe_memory: SafeMemory;
cfg_if::cfg_if! {
if #[cfg(feature = "circom-2")] {
//let version = instance.get_version()?;
n32 = instance.get_field_num_len32()?;
safe_memory = SafeMemory::new(memory, n32 as usize, BigInt::zero());
let _res = instance.get_raw_prime()?;
let mut arr = vec![0; n32 as usize];
for i in 0..n32 {
let res = instance.read_shared_rw_memory(i)?;
arr[(n32 as usize) - (i as usize) - 1] = res;
}
prime = from_array32(arr);
} else {
// Fallback to Circom 1 behavior
//version = 1;
n32 = (instance.get_fr_len()? >> 2) - 2;
safe_memory = SafeMemory::new(memory, n32 as usize, BigInt::zero());
let ptr = instance.get_ptr_raw_prime()?;
prime = safe_memory.read_big(ptr as usize, n32 as usize)?;
}
}
let ptr = instance.get_ptr_raw_prime()?;
let prime = memory.read_big(ptr as usize, n32 as usize)?;
let n64 = ((prime.bits() - 1) / 64 + 1) as i32; let n64 = ((prime.bits() - 1) / 64 + 1) as i32;
memory.prime = prime;
safe_memory.prime = prime;
Ok(WitnessCalculator { Ok(WitnessCalculator {
instance, instance,
memory,
memory: safe_memory,
n64, n64,
}) })
} }
@ -162,6 +199,20 @@ mod runtime {
Function::new_native(store, func) Function::new_native(store, func)
} }
// Circom 2.0
pub fn exception_handler(store: &Store) -> Function {
#[allow(unused)]
fn func(a: i32) {}
Function::new_native(store, func)
}
// Circom 2.0
pub fn show_memory(store: &Store) -> Function {
#[allow(unused)]
fn func() {}
Function::new_native(store, func)
}
pub fn log_signal(store: &Store) -> Function { pub fn log_signal(store: &Store) -> Function {
#[allow(unused)] #[allow(unused)]
fn func(a: i32, b: i32) {} fn func(a: i32, b: i32) {}

BIN
test-vectors/circom2_multiplier2.r1cs


BIN
test-vectors/circom2_multiplier2.wasm


+ 32
- 0
tests/groth16.rs

@ -58,3 +58,35 @@ fn groth16_proof_wrong_input() {
builder.build().unwrap_err(); builder.build().unwrap_err();
} }
#[test]
#[cfg(feature = "circom-2")]
fn groth16_proof_circom2() -> Result<()> {
let cfg = CircomConfig::<Bn254>::new(
"./test-vectors/circom2_multiplier2.wasm",
"./test-vectors/circom2_multiplier2.r1cs",
)?;
let mut builder = CircomBuilder::new(cfg);
builder.push_input("a", 3);
builder.push_input("b", 11);
// create an empty instance for setting it up
let circom = builder.setup();
let mut rng = thread_rng();
let params = generate_random_parameters::<Bn254, _, _>(circom, &mut rng)?;
let circom = builder.build()?;
let inputs = circom.get_public_inputs().unwrap();
let proof = prove(circom, &params, &mut rng)?;
let pvk = prepare_verifying_key(&params.vk);
let verified = verify_proof(&pvk, &proof, &inputs)?;
assert!(verified);
Ok(())
}

Loading…
Cancel
Save