Browse Source

Upgrade to wasmer 4.3 (#64)

pull/4/head
Martin Allen 4 months ago
committed by GitHub
parent
commit
967add46da
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
12 changed files with 2306 additions and 1209 deletions
  1. +2
    -2
      .github/workflows/ci.yml
  2. +1959
    -954
      Cargo.lock
  3. +2
    -1
      Cargo.toml
  4. +10
    -6
      benches/groth16.rs
  5. +1
    -1
      rust-toolchain.toml
  6. +30
    -9
      src/circom/builder.rs
  7. +2
    -2
      src/circom/circuit.rs
  8. +85
    -60
      src/witness/circom.rs
  9. +93
    -81
      src/witness/memory.rs
  10. +106
    -78
      src/witness/witness_calculator.rs
  11. +8
    -7
      src/zkey.rs
  12. +8
    -8
      tests/groth16.rs

+ 2
- 2
.github/workflows/ci.yml

@ -16,7 +16,7 @@ jobs:
uses: actions-rs/toolchain@v1 uses: actions-rs/toolchain@v1
with: with:
profile: minimal profile: minimal
toolchain: 1.67.0
toolchain: 1.74.0
override: true override: true
# Install for Anvil # Install for Anvil
@ -48,7 +48,7 @@ jobs:
uses: actions-rs/toolchain@v1 uses: actions-rs/toolchain@v1
with: with:
profile: minimal profile: minimal
toolchain: 1.67.0
toolchain: 1.74.0
override: true override: true
components: rustfmt, clippy components: rustfmt, clippy
- name: cargo fmt - name: cargo fmt

+ 1959
- 954
Cargo.lock
File diff suppressed because it is too large
View File


+ 2
- 1
Cargo.toml

@ -10,7 +10,8 @@ crate-type = ["cdylib", "rlib"]
[dependencies] [dependencies]
# WASM operations # WASM operations
wasmer = { version = "=2.3.0", default-features = false }
wasmer = "4.3.0"
wasmer-wasix = { version = "0.20.0" }
fnv = { version = "=1.0.7", default-features = false } fnv = { version = "=1.0.7", default-features = false }
num = { version = "=0.4.0" } num = { version = "=0.4.0" }
num-traits = { version = "=0.2.15", default-features = false } num-traits = { version = "=0.2.15", default-features = false }

+ 10
- 6
benches/groth16.rs

@ -6,6 +6,7 @@ use ark_std::rand::thread_rng;
use ark_bn254::Bn254; use ark_bn254::Bn254;
use ark_groth16::Groth16; use ark_groth16::Groth16;
use wasmer::Store;
use std::{collections::HashMap, fs::File}; use std::{collections::HashMap, fs::File};
@ -28,14 +29,17 @@ fn bench_groth(c: &mut Criterion, num_validators: u32, num_constraints: u32) {
inputs inputs
}; };
let mut wtns = WitnessCalculator::new(format!(
"./test-vectors/complex-circuit/complex-circuit-{}-{}.wasm",
i, j
))
let mut store = Store::default();
let mut wtns = WitnessCalculator::new(
&mut store,
format!(
"./test-vectors/complex-circuit/complex-circuit-{}-{}.wasm",
i, j
),
)
.unwrap(); .unwrap();
let full_assignment = wtns let full_assignment = wtns
.calculate_witness_element::<Bn254, _>(inputs, false)
.calculate_witness_element::<Bn254, _>(&mut store, inputs, false)
.unwrap(); .unwrap();
let mut rng = thread_rng(); let mut rng = thread_rng();

+ 1
- 1
rust-toolchain.toml

@ -1,3 +1,3 @@
[toolchain] [toolchain]
channel = "stable" channel = "stable"
version = "1.67.0"
version = "1.74.0"

+ 30
- 9
src/circom/builder.rs

@ -1,36 +1,56 @@
use ark_ec::pairing::Pairing; use ark_ec::pairing::Pairing;
use std::{fs::File, path::Path}; use std::{fs::File, path::Path};
use wasmer::Store;
use super::{CircomCircuit, R1CS}; use super::{CircomCircuit, R1CS};
use num_bigint::BigInt; use num_bigint::BigInt;
use std::collections::HashMap; use std::collections::HashMap;
use crate::{circom::R1CSFile, witness::WitnessCalculator};
use crate::{
circom::R1CSFile,
witness::{Wasm, WitnessCalculator},
};
use color_eyre::Result; use color_eyre::Result;
#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct CircomBuilder<E: Pairing> { pub struct CircomBuilder<E: Pairing> {
pub cfg: CircomConfig<E>, pub cfg: CircomConfig<E>,
pub inputs: HashMap<String, Vec<BigInt>>, pub inputs: HashMap<String, Vec<BigInt>>,
} }
// Add utils for creating this from files / directly from bytes // Add utils for creating this from files / directly from bytes
#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct CircomConfig<E: Pairing> { pub struct CircomConfig<E: Pairing> {
pub r1cs: R1CS<E>, pub r1cs: R1CS<E>,
pub wtns: WitnessCalculator, pub wtns: WitnessCalculator,
pub store: Store,
pub sanity_check: bool, pub sanity_check: bool,
} }
impl<E: Pairing> CircomConfig<E> { impl<E: Pairing> CircomConfig<E> {
pub fn new(wtns: impl AsRef<Path>, r1cs: impl AsRef<Path>) -> Result<Self> { pub fn new(wtns: impl AsRef<Path>, r1cs: impl AsRef<Path>) -> Result<Self> {
let wtns = WitnessCalculator::new(wtns).unwrap();
let mut store = Store::default();
let wtns = WitnessCalculator::new(&mut store, wtns).unwrap();
let reader = File::open(r1cs)?; let reader = File::open(r1cs)?;
let r1cs = R1CSFile::new(reader)?.into(); let r1cs = R1CSFile::new(reader)?.into();
Ok(Self { Ok(Self {
wtns, wtns,
r1cs, r1cs,
store,
sanity_check: false,
})
}
pub fn new_from_wasm(wasm: Wasm, r1cs: impl AsRef<Path>) -> Result<Self> {
let mut store = Store::default();
let wtns = WitnessCalculator::new_from_wasm(&mut store, wasm).unwrap();
let reader = File::open(r1cs)?;
let r1cs = R1CSFile::new(reader)?.into();
Ok(Self {
wtns,
r1cs,
store,
sanity_check: false, sanity_check: false,
}) })
} }
@ -48,7 +68,7 @@ impl CircomBuilder {
/// Pushes a Circom input at the specified name. /// Pushes a Circom input at the specified name.
pub fn push_input<T: Into<BigInt>>(&mut self, name: impl ToString, val: T) { pub fn push_input<T: Into<BigInt>>(&mut self, name: impl ToString, val: T) {
let values = self.inputs.entry(name.to_string()).or_insert_with(Vec::new);
let values = self.inputs.entry(name.to_string()).or_default();
values.push(val.into()); values.push(val.into());
} }
@ -72,10 +92,11 @@ impl CircomBuilder {
let mut circom = self.setup(); let mut circom = self.setup();
// calculate the witness // calculate the witness
let witness = self
.cfg
.wtns
.calculate_witness_element::<E, _>(self.inputs, self.cfg.sanity_check)?;
let witness = self.cfg.wtns.calculate_witness_element::<E, _>(
&mut self.cfg.store,
self.inputs,
self.cfg.sanity_check,
)?;
circom.witness = Some(witness); circom.witness = Some(witness);
// sanity check // sanity check

+ 2
- 2
src/circom/circuit.rs

@ -93,8 +93,8 @@ mod tests {
use ark_bn254::{Bn254, Fr}; use ark_bn254::{Bn254, Fr};
use ark_relations::r1cs::ConstraintSystem; use ark_relations::r1cs::ConstraintSystem;
#[test]
fn satisfied() {
#[tokio::test]
async fn satisfied() {
let cfg = CircomConfig::<Bn254>::new( let cfg = CircomConfig::<Bn254>::new(
"./test-vectors/mycircuit.wasm", "./test-vectors/mycircuit.wasm",
"./test-vectors/mycircuit.r1cs", "./test-vectors/mycircuit.r1cs",

+ 85
- 60
src/witness/circom.rs

@ -1,79 +1,105 @@
use color_eyre::Result; use color_eyre::Result;
use wasmer::{Function, Instance, Value};
use wasmer::{Exports, Function, Memory, Store, Value};
#[derive(Clone, Debug)]
pub struct Wasm(Instance);
#[derive(Debug)]
pub struct Wasm {
pub exports: Exports,
pub memory: Memory,
}
pub trait CircomBase { pub trait CircomBase {
fn init(&self, sanity_check: bool) -> Result<()>;
fn init(&self, store: &mut Store, sanity_check: bool) -> Result<()>;
fn func(&self, name: &str) -> &Function; fn func(&self, name: &str) -> &Function;
fn get_n_vars(&self) -> Result<u32>;
fn get_u32(&self, name: &str) -> Result<u32>;
fn get_n_vars(&self, store: &mut Store) -> Result<u32>;
fn get_u32(&self, store: &mut Store, name: &str) -> Result<u32>;
// Only exists natively in Circom2, hardcoded for Circom // Only exists natively in Circom2, hardcoded for Circom
fn get_version(&self) -> Result<u32>;
fn get_version(&self, store: &mut Store) -> Result<u32>;
} }
pub trait Circom1 { pub trait Circom1 {
fn get_ptr_witness(&self, w: u32) -> Result<u32>;
fn get_fr_len(&self) -> Result<u32>;
fn get_ptr_witness(&self, store: &mut Store, w: u32) -> Result<u32>;
fn get_fr_len(&self, store: &mut Store) -> Result<u32>;
fn get_signal_offset32( fn get_signal_offset32(
&self, &self,
store: &mut Store,
p_sig_offset: u32, p_sig_offset: u32,
component: u32, component: u32,
hash_msb: u32, hash_msb: u32,
hash_lsb: u32, hash_lsb: u32,
) -> Result<()>; ) -> Result<()>;
fn set_signal(&self, c_idx: u32, component: u32, signal: u32, p_val: u32) -> Result<()>;
fn get_ptr_raw_prime(&self) -> Result<u32>;
fn set_signal(
&self,
store: &mut Store,
c_idx: u32,
component: u32,
signal: u32,
p_val: u32,
) -> Result<()>;
fn get_ptr_raw_prime(&self, store: &mut Store) -> Result<u32>;
} }
pub trait Circom2 { pub trait Circom2 {
fn get_field_num_len32(&self) -> Result<u32>;
fn get_raw_prime(&self) -> Result<()>;
fn read_shared_rw_memory(&self, i: u32) -> Result<u32>;
fn write_shared_rw_memory(&self, i: u32, v: u32) -> Result<()>;
fn set_input_signal(&self, hmsb: u32, hlsb: u32, pos: u32) -> Result<()>;
fn get_witness(&self, i: u32) -> Result<()>;
fn get_witness_size(&self) -> Result<u32>;
fn get_field_num_len32(&self, store: &mut Store) -> Result<u32>;
fn get_raw_prime(&self, store: &mut Store) -> Result<()>;
fn read_shared_rw_memory(&self, store: &mut Store, i: u32) -> Result<u32>;
fn write_shared_rw_memory(&self, store: &mut Store, i: u32, v: u32) -> Result<()>;
fn set_input_signal(&self, store: &mut Store, hmsb: u32, hlsb: u32, pos: u32) -> Result<()>;
fn get_witness(&self, store: &mut Store, i: u32) -> Result<()>;
fn get_witness_size(&self, store: &mut Store) -> Result<u32>;
} }
impl Circom1 for Wasm { impl Circom1 for Wasm {
fn get_fr_len(&self) -> Result<u32> {
self.get_u32("getFrLen")
fn get_fr_len(&self, store: &mut Store) -> Result<u32> {
self.get_u32(store, "getFrLen")
} }
fn get_ptr_raw_prime(&self) -> Result<u32> {
self.get_u32("getPRawPrime")
fn get_ptr_raw_prime(&self, store: &mut Store) -> Result<u32> {
self.get_u32(store, "getPRawPrime")
} }
fn get_ptr_witness(&self, w: u32) -> Result<u32> {
fn get_ptr_witness(&self, store: &mut Store, w: u32) -> Result<u32> {
let func = self.func("getPWitness"); let func = self.func("getPWitness");
let res = func.call(&[w.into()])?;
let res = func.call(store, &[w.into()])?;
Ok(res[0].unwrap_i32() as u32) Ok(res[0].unwrap_i32() as u32)
} }
fn get_signal_offset32( fn get_signal_offset32(
&self, &self,
store: &mut Store,
p_sig_offset: u32, p_sig_offset: u32,
component: u32, component: u32,
hash_msb: u32, hash_msb: u32,
hash_lsb: u32, hash_lsb: u32,
) -> Result<()> { ) -> Result<()> {
let func = self.func("getSignalOffset32"); let func = self.func("getSignalOffset32");
func.call(&[
p_sig_offset.into(),
component.into(),
hash_msb.into(),
hash_lsb.into(),
])?;
func.call(
store,
&[
p_sig_offset.into(),
component.into(),
hash_msb.into(),
hash_lsb.into(),
],
)?;
Ok(()) Ok(())
} }
fn set_signal(&self, c_idx: u32, component: u32, signal: u32, p_val: u32) -> Result<()> {
fn set_signal(
&self,
store: &mut Store,
c_idx: u32,
component: u32,
signal: u32,
p_val: u32,
) -> Result<()> {
let func = self.func("setSignal"); let func = self.func("setSignal");
func.call(&[c_idx.into(), component.into(), signal.into(), p_val.into()])?;
func.call(
store,
&[c_idx.into(), component.into(), signal.into(), p_val.into()],
)?;
Ok(()) Ok(())
} }
@ -81,80 +107,79 @@ impl Circom1 for Wasm {
#[cfg(feature = "circom-2")] #[cfg(feature = "circom-2")]
impl Circom2 for Wasm { impl Circom2 for Wasm {
fn get_field_num_len32(&self) -> Result<u32> {
self.get_u32("getFieldNumLen32")
fn get_field_num_len32(&self, store: &mut Store) -> Result<u32> {
self.get_u32(store, "getFieldNumLen32")
} }
fn get_raw_prime(&self) -> Result<()> {
fn get_raw_prime(&self, store: &mut Store) -> Result<()> {
let func = self.func("getRawPrime"); let func = self.func("getRawPrime");
func.call(&[])?;
func.call(store, &[])?;
Ok(()) Ok(())
} }
fn read_shared_rw_memory(&self, i: u32) -> Result<u32> {
fn read_shared_rw_memory(&self, store: &mut Store, i: u32) -> Result<u32> {
let func = self.func("readSharedRWMemory"); let func = self.func("readSharedRWMemory");
let result = func.call(&[i.into()])?;
let result = func.call(store, &[i.into()])?;
Ok(result[0].unwrap_i32() as u32) Ok(result[0].unwrap_i32() as u32)
} }
fn write_shared_rw_memory(&self, i: u32, v: u32) -> Result<()> {
fn write_shared_rw_memory(&self, store: &mut Store, i: u32, v: u32) -> Result<()> {
let func = self.func("writeSharedRWMemory"); let func = self.func("writeSharedRWMemory");
func.call(&[i.into(), v.into()])?;
func.call(store, &[i.into(), v.into()])?;
Ok(()) Ok(())
} }
fn set_input_signal(&self, hmsb: u32, hlsb: u32, pos: u32) -> Result<()> {
fn set_input_signal(&self, store: &mut Store, hmsb: u32, hlsb: u32, pos: u32) -> Result<()> {
let func = self.func("setInputSignal"); let func = self.func("setInputSignal");
func.call(&[hmsb.into(), hlsb.into(), pos.into()])?;
func.call(store, &[hmsb.into(), hlsb.into(), pos.into()])?;
Ok(()) Ok(())
} }
fn get_witness(&self, i: u32) -> Result<()> {
fn get_witness(&self, store: &mut Store, i: u32) -> Result<()> {
let func = self.func("getWitness"); let func = self.func("getWitness");
func.call(&[i.into()])?;
func.call(store, &[i.into()])?;
Ok(()) Ok(())
} }
fn get_witness_size(&self) -> Result<u32> {
self.get_u32("getWitnessSize")
fn get_witness_size(&self, store: &mut Store) -> Result<u32> {
self.get_u32(store, "getWitnessSize")
} }
} }
impl CircomBase for Wasm { impl CircomBase for Wasm {
fn init(&self, sanity_check: bool) -> Result<()> {
fn init(&self, store: &mut Store, sanity_check: bool) -> Result<()> {
let func = self.func("init"); let func = self.func("init");
func.call(&[Value::I32(sanity_check as i32)])?;
func.call(store, &[Value::I32(sanity_check as i32)])?;
Ok(()) Ok(())
} }
fn get_n_vars(&self) -> Result<u32> {
self.get_u32("getNVars")
fn get_n_vars(&self, store: &mut Store) -> Result<u32> {
self.get_u32(store, "getNVars")
} }
// Default to version 1 if it isn't explicitly defined // Default to version 1 if it isn't explicitly defined
fn get_version(&self) -> Result<u32> {
match self.0.exports.get_function("getVersion") {
Ok(func) => Ok(func.call(&[])?[0].unwrap_i32() as u32),
fn get_version(&self, store: &mut Store) -> Result<u32> {
match self.exports.get_function("getVersion") {
Ok(func) => Ok(func.call(store, &[])?[0].unwrap_i32() as u32),
Err(_) => Ok(1), Err(_) => Ok(1),
} }
} }
fn get_u32(&self, name: &str) -> Result<u32> {
let func = self.func(name);
let result = func.call(&[])?;
fn get_u32(&self, store: &mut Store, name: &str) -> Result<u32> {
let func = &self.func(name);
let result = func.call(store, &[])?;
Ok(result[0].unwrap_i32() as u32) Ok(result[0].unwrap_i32() as u32)
} }
fn func(&self, name: &str) -> &Function { fn func(&self, name: &str) -> &Function {
self.0
.exports
self.exports
.get_function(name) .get_function(name)
.unwrap_or_else(|_| panic!("function {} not found", name)) .unwrap_or_else(|_| panic!("function {} not found", name))
} }
} }
impl Wasm { impl Wasm {
pub fn new(instance: Instance) -> Self {
Self(instance)
pub fn new(exports: Exports, memory: Memory) -> Self {
Self { exports, memory }
} }
} }

+ 93
- 81
src/witness/memory.rs

@ -1,7 +1,7 @@
//! Safe-ish interface for reading and writing specific types to the WASM runtime's memory //! Safe-ish interface for reading and writing specific types to the WASM runtime's memory
use ark_serialize::CanonicalDeserialize; use ark_serialize::CanonicalDeserialize;
use num_traits::ToPrimitive; use num_traits::ToPrimitive;
use wasmer::{Memory, MemoryView};
use wasmer::{Memory, MemoryAccessError, MemoryView, Store};
// TODO: Decide whether we want Ark here or if it should use a generic BigInt package // TODO: Decide whether we want Ark here or if it should use a generic BigInt package
use ark_bn254::FrConfig; use ark_bn254::FrConfig;
@ -11,10 +11,11 @@ use ark_ff::{BigInteger, BigInteger256, Zero};
use num_bigint::{BigInt, BigUint}; use num_bigint::{BigInt, BigUint};
use color_eyre::Result; use color_eyre::Result;
use std::io::Cursor;
use std::str::FromStr; use std::str::FromStr;
use std::{convert::TryFrom, ops::Deref}; use std::{convert::TryFrom, ops::Deref};
#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct SafeMemory { pub struct SafeMemory {
pub memory: Memory, pub memory: Memory,
pub prime: BigInt, pub prime: BigInt,
@ -38,10 +39,9 @@ impl SafeMemory {
pub fn new(memory: Memory, n32: usize, prime: BigInt) -> Self { pub fn new(memory: Memory, n32: usize, prime: BigInt) -> Self {
// TODO: Figure out a better way to calculate these // TODO: Figure out a better way to calculate these
let short_max = BigInt::from(0x8000_0000u64); let short_max = BigInt::from(0x8000_0000u64);
let short_min = BigInt::from_biguint(
num_bigint::Sign::NoSign,
BigUint::try_from(FrConfig::MODULUS).unwrap(),
) - &short_max;
let short_min =
BigInt::from_biguint(num_bigint::Sign::NoSign, BigUint::from(FrConfig::MODULUS))
- &short_max;
let r_inv = BigInt::from_str( let r_inv = BigInt::from_str(
"9915499612839321149637521777990102151350674507940716049588462388200839649614", "9915499612839321149637521777990102151350674507940716049588462388200839649614",
) )
@ -59,96 +59,103 @@ impl SafeMemory {
} }
/// Gets an immutable view to the memory in 32 byte chunks /// Gets an immutable view to the memory in 32 byte chunks
pub fn view(&self) -> MemoryView<u32> {
self.memory.view()
pub fn view<'a>(&self, store: &'a mut Store) -> MemoryView<'a> {
self.memory.view(store)
} }
/// Returns the next free position in the memory /// Returns the next free position in the memory
pub fn free_pos(&self) -> u32 {
self.view()[0].get()
pub fn free_pos(&self, store: &mut Store) -> Result<u32, MemoryAccessError> {
self.read_u32(store, 0)
} }
/// Sets the next free position in the memory /// Sets the next free position in the memory
pub fn set_free_pos(&mut self, ptr: u32) {
self.write_u32(0, ptr);
pub fn set_free_pos(&self, store: &mut Store, ptr: u32) -> Result<(), MemoryAccessError> {
self.write_u32(store, 0, ptr)
} }
/// Allocates a U32 in memory /// Allocates a U32 in memory
pub fn alloc_u32(&mut self) -> u32 {
let p = self.free_pos();
self.set_free_pos(p + 8);
p
pub fn alloc_u32(&self, store: &mut Store) -> Result<u32, MemoryAccessError> {
let p = self.free_pos(store)?;
self.set_free_pos(store, p + 8)?;
Ok(p)
} }
/// Writes a u32 to the specified memory offset /// Writes a u32 to the specified memory offset
pub fn write_u32(&mut self, ptr: usize, num: u32) {
let buf = unsafe { self.memory.data_unchecked_mut() };
buf[ptr..ptr + std::mem::size_of::<u32>()].copy_from_slice(&num.to_le_bytes());
pub fn write_u32(
&self,
store: &mut Store,
ptr: usize,
num: u32,
) -> Result<(), MemoryAccessError> {
let bytes = num.to_le_bytes();
self.view(store).write(ptr as u64, &bytes)
} }
/// Reads a u32 from the specified memory offset /// Reads a u32 from the specified memory offset
pub fn read_u32(&self, ptr: usize) -> u32 {
let buf = unsafe { self.memory.data_unchecked() };
pub fn read_u32(&self, store: &mut Store, ptr: usize) -> Result<u32, MemoryAccessError> {
let mut bytes = [0; 4]; let mut bytes = [0; 4];
bytes.copy_from_slice(&buf[ptr..ptr + std::mem::size_of::<u32>()]);
self.view(store).read(ptr as u64, &mut bytes)?;
Ok(u32::from_le_bytes(bytes))
}
u32::from_le_bytes(bytes)
pub fn read_byte(&self, store: &mut Store, ptr: usize) -> Result<u8, MemoryAccessError> {
let mut bytes = [0; 1];
self.view(store).read(ptr as u64, &mut bytes)?;
Ok(u8::from_le_bytes(bytes))
} }
/// Allocates `self.n32 * 4 + 8` bytes in the memory /// Allocates `self.n32 * 4 + 8` bytes in the memory
pub fn alloc_fr(&mut self) -> u32 {
let p = self.free_pos();
self.set_free_pos(p + self.n32 as u32 * 4 + 8);
p
pub fn alloc_fr(&self, store: &mut Store) -> Result<u32, MemoryAccessError> {
let p = self.free_pos(store)?;
self.set_free_pos(store, p + self.n32 as u32 * 4 + 8)?;
Ok(p)
} }
/// Writes a Field Element to memory at the specified offset, truncating /// Writes a Field Element to memory at the specified offset, truncating
/// to smaller u32 types if needed and adjusting the sign via 2s complement /// to smaller u32 types if needed and adjusting the sign via 2s complement
pub fn write_fr(&mut self, ptr: usize, fr: &BigInt) -> Result<()> {
pub fn write_fr(&self, store: &mut Store, ptr: usize, fr: &BigInt) -> Result<()> {
if fr < &self.short_max && fr > &self.short_min { if fr < &self.short_max && fr > &self.short_min {
if fr >= &BigInt::zero() { if fr >= &BigInt::zero() {
self.write_short_positive(ptr, fr)?;
self.write_short_positive(store, ptr, fr)?;
} else { } else {
self.write_short_negative(ptr, fr)?;
self.write_short_negative(store, ptr, fr)?;
} }
} else { } else {
self.write_long_normal(ptr, fr)?;
self.write_long_normal(store, ptr, fr)?;
} }
Ok(()) Ok(())
} }
/// Reads a Field Element from the memory at the specified offset /// Reads a Field Element from the memory at the specified offset
pub fn read_fr(&self, ptr: usize) -> Result<BigInt> {
let view = self.memory.view::<u8>();
pub fn read_fr(&self, store: &mut Store, ptr: usize) -> Result<BigInt, MemoryAccessError> {
let test_byte = self.read_byte(store, ptr + 4 + 3)?;
let test_byte2 = self.read_byte(store, ptr + 3)?;
let res = if view[ptr + 4 + 3].get() & 0x80 != 0 {
let mut num = self.read_big(ptr + 8, self.n32)?;
if view[ptr + 4 + 3].get() & 0x40 != 0 {
if test_byte & 0x80 != 0 {
let mut num = self.read_big(store, ptr + 8, self.n32)?;
if test_byte & 0x40 != 0 {
num = (num * &self.r_inv) % &self.prime num = (num * &self.r_inv) % &self.prime
} }
num
} else if view[ptr + 3].get() & 0x40 != 0 {
let mut num = self.read_u32(ptr).into();
Ok(num)
} else if test_byte2 & 0x40 != 0 {
let mut num = self.read_u32(store, ptr).map(|x| x.into())?;
// handle small negative // handle small negative
num -= BigInt::from(0x100000000i64); num -= BigInt::from(0x100000000i64);
num
Ok(num)
} else { } else {
self.read_u32(ptr).into()
};
Ok(res)
self.read_u32(store, ptr).map(|x| x.into())
}
} }
fn write_short_positive(&mut self, ptr: usize, fr: &BigInt) -> Result<()> {
fn write_short_positive(&self, store: &mut Store, ptr: usize, fr: &BigInt) -> Result<()> {
let num = fr.to_i32().expect("not a short positive"); let num = fr.to_i32().expect("not a short positive");
self.write_u32(ptr, num as u32);
self.write_u32(ptr + 4, 0);
self.write_u32(store, ptr, num as u32)?;
self.write_u32(store, ptr + 4, 0)?;
Ok(()) Ok(())
} }
fn write_short_negative(&mut self, ptr: usize, fr: &BigInt) -> Result<()> {
fn write_short_negative(&self, store: &mut Store, ptr: usize, fr: &BigInt) -> Result<()> {
// 2s complement // 2s complement
let num = fr - &self.short_min; let num = fr - &self.short_min;
let num = num - &self.short_max; let num = num - &self.short_max;
@ -158,40 +165,43 @@ impl SafeMemory {
.to_u32() .to_u32()
.expect("could not cast as u32 (should never happen)"); .expect("could not cast as u32 (should never happen)");
self.write_u32(ptr, num);
self.write_u32(ptr + 4, 0);
self.write_u32(store, ptr, num)?;
self.write_u32(store, ptr + 4, 0)?;
Ok(()) Ok(())
} }
fn write_long_normal(&mut self, ptr: usize, fr: &BigInt) -> Result<()> {
self.write_u32(ptr, 0);
self.write_u32(ptr + 4, i32::MIN as u32); // 0x80000000
self.write_big(ptr + 8, fr)?;
fn write_long_normal(&self, store: &mut Store, ptr: usize, fr: &BigInt) -> Result<()> {
self.write_u32(store, ptr, 0)?;
self.write_u32(store, ptr + 4, i32::MIN as u32)?; // 0x80000000
self.write_big(store, ptr + 8, fr)?;
Ok(()) Ok(())
} }
fn write_big(&self, ptr: usize, num: &BigInt) -> Result<()> {
let buf = unsafe { self.memory.data_unchecked_mut() };
// TODO: How do we handle negative bignums?
fn write_big(
&self,
store: &mut Store,
ptr: usize,
num: &BigInt,
) -> Result<(), MemoryAccessError> {
let (_, num) = num.clone().into_parts(); let (_, num) = num.clone().into_parts();
let num = BigInteger256::try_from(num).unwrap(); let num = BigInteger256::try_from(num).unwrap();
let bytes = num.to_bytes_le(); let bytes = num.to_bytes_le();
let len = bytes.len();
buf[ptr..ptr + len].copy_from_slice(&bytes);
Ok(())
self.view(store).write(ptr as u64, &bytes)
} }
/// Reads `num_bytes * 32` from the specified memory offset in a Big Integer /// Reads `num_bytes * 32` from the specified memory offset in a Big Integer
pub fn read_big(&self, ptr: usize, num_bytes: usize) -> Result<BigInt> {
let buf = unsafe { self.memory.data_unchecked() };
let buf = &buf[ptr..ptr + num_bytes * 32];
pub fn read_big(
&self,
store: &mut Store,
ptr: usize,
num_bytes: usize,
) -> Result<BigInt, MemoryAccessError> {
let mut buf = vec![0; num_bytes * 32];
self.view(store).read(ptr as u64, &mut buf)?;
// TODO: Is there a better way to read big integers? // TODO: Is there a better way to read big integers?
let big = BigInteger256::deserialize_uncompressed(buf).unwrap();
let big = BigUint::try_from(big).unwrap();
let big = BigInteger256::deserialize_uncompressed(&mut Cursor::new(buf)).unwrap();
let big = BigUint::from(big);
Ok(big.into()) Ok(big.into())
} }
} }
@ -210,20 +220,22 @@ mod tests {
use std::str::FromStr; use std::str::FromStr;
use wasmer::{MemoryType, Store}; use wasmer::{MemoryType, Store};
fn new() -> SafeMemory {
SafeMemory::new(
Memory::new(&Store::default(), MemoryType::new(1, None, false)).unwrap(),
fn new() -> (SafeMemory, Store) {
let mut store = Store::default();
let mem = SafeMemory::new(
Memory::new(&mut store, MemoryType::new(1, None, false)).unwrap(),
2, 2,
BigInt::from_str( BigInt::from_str(
"21888242871839275222246405745257275088548364400416034343698204186575808495617", "21888242871839275222246405745257275088548364400416034343698204186575808495617",
) )
.unwrap(), .unwrap(),
)
);
(mem, store)
} }
#[test] #[test]
fn i32_bounds() { fn i32_bounds() {
let mem = new();
let (mem, _) = new();
let i32_max = i32::MAX as i64 + 1; let i32_max = i32::MAX as i64 + 1;
assert_eq!(mem.short_min.to_i64().unwrap(), -i32_max); assert_eq!(mem.short_min.to_i64().unwrap(), -i32_max);
assert_eq!(mem.short_max.to_i64().unwrap(), i32_max); assert_eq!(mem.short_max.to_i64().unwrap(), i32_max);
@ -231,14 +243,14 @@ mod tests {
#[test] #[test]
fn read_write_32() { fn read_write_32() {
let mut mem = new();
let (mem, mut store) = new();
let num = u32::MAX; let num = u32::MAX;
let inp = mem.read_u32(0);
let inp = mem.read_u32(&mut store, 0).unwrap();
assert_eq!(inp, 0); assert_eq!(inp, 0);
mem.write_u32(0, num);
let inp = mem.read_u32(0);
mem.write_u32(&mut store, 0, num).unwrap();
let inp = mem.read_u32(&mut store, 0).unwrap();
assert_eq!(inp, num); assert_eq!(inp, num);
} }
@ -265,9 +277,9 @@ mod tests {
} }
fn read_write_fr(num: BigInt) { fn read_write_fr(num: BigInt) {
let mut mem = new();
mem.write_fr(0, &num).unwrap();
let res = mem.read_fr(0).unwrap();
let (mem, mut store) = new();
mem.write_fr(&mut store, 0, &num).unwrap();
let res = mem.read_fr(&mut store, 0).unwrap();
assert_eq!(res, num); assert_eq!(res, num);
} }
} }

+ 106
- 78
src/witness/witness_calculator.rs

@ -3,6 +3,7 @@ use color_eyre::Result;
use num_bigint::BigInt; use num_bigint::BigInt;
use num_traits::Zero; use num_traits::Zero;
use wasmer::{imports, Function, Instance, Memory, MemoryType, Module, RuntimeError, Store}; use wasmer::{imports, Function, Instance, Memory, MemoryType, Module, RuntimeError, Store};
use wasmer_wasix::WasiEnv;
#[cfg(feature = "circom-2")] #[cfg(feature = "circom-2")]
use num::ToPrimitive; use num::ToPrimitive;
@ -11,7 +12,7 @@ use super::Circom1;
#[cfg(feature = "circom-2")] #[cfg(feature = "circom-2")]
use super::Circom2; use super::Circom2;
#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct WitnessCalculator { pub struct WitnessCalculator {
pub instance: Wasm, pub instance: Wasm,
pub memory: Option<SafeMemory>, pub memory: Option<SafeMemory>,
@ -52,20 +53,21 @@ fn to_array32(s: &BigInt, size: usize) -> Vec {
} }
impl WitnessCalculator { impl WitnessCalculator {
pub fn new(path: impl AsRef<std::path::Path>) -> Result<Self> {
Self::from_file(path)
pub fn new(store: &mut Store, path: impl AsRef<std::path::Path>) -> Result<Self> {
Self::from_file(store, path)
} }
pub fn from_file(path: impl AsRef<std::path::Path>) -> Result<Self> {
let store = Store::default();
pub fn from_file(store: &mut Store, path: impl AsRef<std::path::Path>) -> Result<Self> {
let module = Module::from_file(&store, path)?; let module = Module::from_file(&store, path)?;
Self::from_module(module)
Self::from_module(store, module)
} }
pub fn from_module(module: Module) -> Result<Self> {
let store = module.store();
pub fn from_module(store: &mut Store, module: Module) -> Result<Self> {
let wasm = Self::make_wasm_runtime(store, module)?;
Self::new_from_wasm(store, wasm)
}
// Set up the memory
pub fn make_wasm_runtime(store: &mut Store, module: Module) -> Result<Wasm> {
let memory = Memory::new(store, MemoryType::new(2000, None, false)).unwrap(); let memory = Memory::new(store, MemoryType::new(2000, None, false)).unwrap();
let import_object = imports! { let import_object = imports! {
"env" => { "env" => {
@ -85,18 +87,28 @@ impl WitnessCalculator {
"writeBufferMessage" => runtime::write_buffer_message(store), "writeBufferMessage" => runtime::write_buffer_message(store),
} }
}; };
let instance = Wasm::new(Instance::new(&module, &import_object)?);
let version = instance.get_version().unwrap_or(1);
let instance = Instance::new(store, &module, &import_object)?;
let exports = instance.exports.clone();
let mut wasi_env = WasiEnv::builder("calculateWitness").finalize(store)?;
wasi_env.initialize_with_memory(store, instance, Some(memory.clone()), false)?;
let wasm = Wasm::new(exports, memory);
Ok(wasm)
}
pub fn new_from_wasm(store: &mut Store, wasm: Wasm) -> Result<Self> {
let version = wasm.get_version(store).unwrap_or(1);
// Circom 2 feature flag with version 2 // Circom 2 feature flag with version 2
#[cfg(feature = "circom-2")] #[cfg(feature = "circom-2")]
fn new_circom2(instance: Wasm, version: u32) -> Result<WitnessCalculator> {
let n32 = instance.get_field_num_len32()?;
instance.get_raw_prime()?;
fn new_circom2(
instance: Wasm,
store: &mut Store,
version: u32,
) -> Result<WitnessCalculator> {
let n32 = instance.get_field_num_len32(store)?;
instance.get_raw_prime(store)?;
let mut arr = vec![0; n32 as usize]; let mut arr = vec![0; n32 as usize];
for i in 0..n32 { for i in 0..n32 {
let res = instance.read_shared_rw_memory(i)?;
let res = instance.read_shared_rw_memory(store, i)?;
arr[(n32 as usize) - (i as usize) - 1] = res; arr[(n32 as usize) - (i as usize) - 1] = res;
} }
let prime = from_array32(arr); let prime = from_array32(arr);
@ -112,12 +124,17 @@ impl WitnessCalculator {
}) })
} }
fn new_circom1(instance: Wasm, memory: Memory, version: u32) -> Result<WitnessCalculator> {
fn new_circom1(
instance: Wasm,
store: &mut Store,
version: u32,
) -> Result<WitnessCalculator> {
// Fallback to Circom 1 behavior // Fallback to Circom 1 behavior
let n32 = (instance.get_fr_len()? >> 2) - 2;
let mut safe_memory = SafeMemory::new(memory, n32 as usize, BigInt::zero());
let ptr = instance.get_ptr_raw_prime()?;
let prime = safe_memory.read_big(ptr as usize, n32 as usize)?;
let n32 = (instance.get_fr_len(store)? >> 2) - 2;
let mut safe_memory =
SafeMemory::new(instance.memory.clone(), n32 as usize, BigInt::zero());
let ptr = instance.get_ptr_raw_prime(store)?;
let prime = safe_memory.read_big(store, ptr as usize, n32 as usize)?;
let n64 = ((prime.bits() - 1) / 64 + 1) as u32; let n64 = ((prime.bits() - 1) / 64 + 1) as u32;
safe_memory.prime = prime.clone(); safe_memory.prime = prime.clone();
@ -141,8 +158,9 @@ impl WitnessCalculator {
cfg_if::cfg_if! { cfg_if::cfg_if! {
if #[cfg(feature = "circom-2")] { if #[cfg(feature = "circom-2")] {
match version { match version {
2 => new_circom2(instance, version),
1 => new_circom1(instance, memory, version),
2 => new_circom2(wasm, store, version),
1 => new_circom1(wasm, store, version),
_ => panic!("Unknown Circom version") _ => panic!("Unknown Circom version")
} }
} else { } else {
@ -153,16 +171,17 @@ impl WitnessCalculator {
pub fn calculate_witness<I: IntoIterator<Item = (String, Vec<BigInt>)>>( pub fn calculate_witness<I: IntoIterator<Item = (String, Vec<BigInt>)>>(
&mut self, &mut self,
store: &mut Store,
inputs: I, inputs: I,
sanity_check: bool, sanity_check: bool,
) -> Result<Vec<BigInt>> { ) -> Result<Vec<BigInt>> {
self.instance.init(sanity_check)?;
self.instance.init(store, sanity_check)?;
cfg_if::cfg_if! { cfg_if::cfg_if! {
if #[cfg(feature = "circom-2")] { if #[cfg(feature = "circom-2")] {
match self.circom_version { match self.circom_version {
2 => self.calculate_witness_circom2(inputs, sanity_check),
1 => self.calculate_witness_circom1(inputs, sanity_check),
2 => self.calculate_witness_circom2(store, inputs),
1 => self.calculate_witness_circom1(store, inputs),
_ => panic!("Unknown Circom version") _ => panic!("Unknown Circom version")
} }
} else { } else {
@ -174,48 +193,50 @@ impl WitnessCalculator {
// Circom 1 default behavior // Circom 1 default behavior
fn calculate_witness_circom1<I: IntoIterator<Item = (String, Vec<BigInt>)>>( fn calculate_witness_circom1<I: IntoIterator<Item = (String, Vec<BigInt>)>>(
&mut self, &mut self,
store: &mut Store,
inputs: I, inputs: I,
sanity_check: bool,
) -> Result<Vec<BigInt>> { ) -> Result<Vec<BigInt>> {
self.instance.init(sanity_check)?;
let old_mem_free_pos = self.memory.as_ref().unwrap().free_pos();
let p_sig_offset = self.memory.as_mut().unwrap().alloc_u32();
let p_fr = self.memory.as_mut().unwrap().alloc_fr();
let old_mem_free_pos = self.memory.as_ref().unwrap().free_pos(store)?;
let p_sig_offset = self.memory.as_mut().unwrap().alloc_u32(store)?;
let p_fr = self.memory.as_mut().unwrap().alloc_fr(store)?;
// allocate the inputs // allocate the inputs
for (name, values) in inputs.into_iter() { for (name, values) in inputs.into_iter() {
let (msb, lsb) = fnv(&name); let (msb, lsb) = fnv(&name);
self.instance self.instance
.get_signal_offset32(p_sig_offset, 0, msb, lsb)?;
.get_signal_offset32(store, p_sig_offset, 0, msb, lsb)?;
let sig_offset = self let sig_offset = self
.memory .memory
.as_ref() .as_ref()
.unwrap() .unwrap()
.read_u32(p_sig_offset as usize) as usize;
.read_u32(store, p_sig_offset as usize)
.unwrap() as usize;
for (i, value) in values.into_iter().enumerate() { for (i, value) in values.into_iter().enumerate() {
self.memory self.memory
.as_mut() .as_mut()
.unwrap() .unwrap()
.write_fr(p_fr as usize, &value)?;
.write_fr(store, p_fr as usize, &value)?;
self.instance self.instance
.set_signal(0, 0, (sig_offset + i) as u32, p_fr)?;
.set_signal(store, 0, 0, (sig_offset + i) as u32, p_fr)?;
} }
} }
let mut w = Vec::new(); let mut w = Vec::new();
let n_vars = self.instance.get_n_vars()?;
let n_vars = self.instance.get_n_vars(store)?;
for i in 0..n_vars { for i in 0..n_vars {
let ptr = self.instance.get_ptr_witness(i)? as usize;
let el = self.memory.as_ref().unwrap().read_fr(ptr)?;
let ptr = self.instance.get_ptr_witness(store, i)? as usize;
let el = self.memory.as_ref().unwrap().read_fr(store, ptr)?;
w.push(el); w.push(el);
} }
self.memory.as_mut().unwrap().set_free_pos(old_mem_free_pos);
self.memory
.as_mut()
.unwrap()
.set_free_pos(store, old_mem_free_pos)?;
Ok(w) Ok(w)
} }
@ -224,12 +245,10 @@ impl WitnessCalculator {
#[cfg(feature = "circom-2")] #[cfg(feature = "circom-2")]
fn calculate_witness_circom2<I: IntoIterator<Item = (String, Vec<BigInt>)>>( fn calculate_witness_circom2<I: IntoIterator<Item = (String, Vec<BigInt>)>>(
&mut self, &mut self,
store: &mut Store,
inputs: I, inputs: I,
sanity_check: bool,
) -> Result<Vec<BigInt>> { ) -> Result<Vec<BigInt>> {
self.instance.init(sanity_check)?;
let n32 = self.instance.get_field_num_len32()?;
let n32 = self.instance.get_field_num_len32(store)?;
// allocate the inputs // allocate the inputs
for (name, values) in inputs.into_iter() { for (name, values) in inputs.into_iter() {
@ -238,21 +257,25 @@ impl WitnessCalculator {
for (i, value) in values.into_iter().enumerate() { for (i, value) in values.into_iter().enumerate() {
let f_arr = to_array32(&value, n32 as usize); let f_arr = to_array32(&value, n32 as usize);
for j in 0..n32 { for j in 0..n32 {
self.instance
.write_shared_rw_memory(j, f_arr[(n32 as usize) - 1 - (j as usize)])?;
self.instance.write_shared_rw_memory(
store,
j,
f_arr[(n32 as usize) - 1 - (j as usize)],
)?;
} }
self.instance.set_input_signal(msb, lsb, i as u32)?;
self.instance.set_input_signal(store, msb, lsb, i as u32)?;
} }
} }
let mut w = Vec::new(); let mut w = Vec::new();
let witness_size = self.instance.get_witness_size()?;
let witness_size = self.instance.get_witness_size(store)?;
for i in 0..witness_size { for i in 0..witness_size {
self.instance.get_witness(i)?;
self.instance.get_witness(store, i)?;
let mut arr = vec![0; n32 as usize]; let mut arr = vec![0; n32 as usize];
for j in 0..n32 { for j in 0..n32 {
arr[(n32 as usize) - 1 - (j as usize)] = self.instance.read_shared_rw_memory(j)?;
arr[(n32 as usize) - 1 - (j as usize)] =
self.instance.read_shared_rw_memory(store, j)?;
} }
w.push(from_array32(arr)); w.push(from_array32(arr));
} }
@ -265,11 +288,12 @@ impl WitnessCalculator {
I: IntoIterator<Item = (String, Vec<BigInt>)>, I: IntoIterator<Item = (String, Vec<BigInt>)>,
>( >(
&mut self, &mut self,
store: &mut Store,
inputs: I, inputs: I,
sanity_check: bool, sanity_check: bool,
) -> Result<Vec<E::ScalarField>> { ) -> Result<Vec<E::ScalarField>> {
use ark_ff::PrimeField; use ark_ff::PrimeField;
let witness = self.calculate_witness(inputs, sanity_check)?;
let witness = self.calculate_witness(store, inputs, sanity_check)?;
let modulus = <E::ScalarField as PrimeField>::MODULUS; let modulus = <E::ScalarField as PrimeField>::MODULUS;
// convert it to field elements // convert it to field elements
@ -295,7 +319,7 @@ impl WitnessCalculator {
mod runtime { mod runtime {
use super::*; use super::*;
pub fn error(store: &Store) -> Function {
pub fn error(store: &mut Store) -> Function {
#[allow(unused)] #[allow(unused)]
#[allow(clippy::many_single_char_names)] #[allow(clippy::many_single_char_names)]
fn func(a: i32, b: i32, c: i32, d: i32, e: i32, f: i32) -> Result<(), RuntimeError> { fn func(a: i32, b: i32, c: i32, d: i32, e: i32, f: i32) -> Result<(), RuntimeError> {
@ -304,47 +328,47 @@ mod runtime {
println!("runtime error, exiting early: {a} {b} {c} {d} {e} {f}",); println!("runtime error, exiting early: {a} {b} {c} {d} {e} {f}",);
Err(RuntimeError::user(Box::new(ExitCode(1)))) Err(RuntimeError::user(Box::new(ExitCode(1))))
} }
Function::new_native(store, func)
Function::new_typed(store, func)
} }
// Circom 2.0 // Circom 2.0
pub fn exception_handler(store: &Store) -> Function {
pub fn exception_handler(store: &mut Store) -> Function {
#[allow(unused)] #[allow(unused)]
fn func(a: i32) {} fn func(a: i32) {}
Function::new_native(store, func)
Function::new_typed(store, func)
} }
// Circom 2.0 // Circom 2.0
pub fn show_memory(store: &Store) -> Function {
pub fn show_memory(store: &mut Store) -> Function {
#[allow(unused)] #[allow(unused)]
fn func() {} fn func() {}
Function::new_native(store, func)
Function::new_typed(store, func)
} }
// Circom 2.0 // Circom 2.0
pub fn print_error_message(store: &Store) -> Function {
pub fn print_error_message(store: &mut Store) -> Function {
#[allow(unused)] #[allow(unused)]
fn func() {} fn func() {}
Function::new_native(store, func)
Function::new_typed(store, func)
} }
// Circom 2.0 // Circom 2.0
pub fn write_buffer_message(store: &Store) -> Function {
pub fn write_buffer_message(store: &mut Store) -> Function {
#[allow(unused)] #[allow(unused)]
fn func() {} fn func() {}
Function::new_native(store, func)
Function::new_typed(store, func)
} }
pub fn log_signal(store: &Store) -> Function {
pub fn log_signal(store: &mut Store) -> Function {
#[allow(unused)] #[allow(unused)]
fn func(a: i32, b: i32) {} fn func(a: i32, b: i32) {}
Function::new_native(store, func)
Function::new_typed(store, func)
} }
pub fn log_component(store: &Store) -> Function {
pub fn log_component(store: &mut Store) -> Function {
#[allow(unused)] #[allow(unused)]
fn func(a: i32) {} fn func(a: i32) {}
Function::new_native(store, func)
Function::new_typed(store, func)
} }
} }
@ -367,8 +391,8 @@ mod tests {
path.to_string_lossy().to_string() path.to_string_lossy().to_string()
} }
#[test]
fn multiplier_1() {
#[tokio::test]
async fn multiplier_1() {
run_test(TestCase { run_test(TestCase {
circuit_path: root_path("test-vectors/mycircuit.wasm").as_str(), circuit_path: root_path("test-vectors/mycircuit.wasm").as_str(),
inputs_path: root_path("test-vectors/mycircuit-input1.json").as_str(), inputs_path: root_path("test-vectors/mycircuit-input1.json").as_str(),
@ -378,8 +402,8 @@ mod tests {
}); });
} }
#[test]
fn multiplier_2() {
#[tokio::test]
async fn multiplier_2() {
run_test(TestCase { run_test(TestCase {
circuit_path: root_path("test-vectors/mycircuit.wasm").as_str(), circuit_path: root_path("test-vectors/mycircuit.wasm").as_str(),
inputs_path: root_path("test-vectors/mycircuit-input2.json").as_str(), inputs_path: root_path("test-vectors/mycircuit-input2.json").as_str(),
@ -394,8 +418,8 @@ mod tests {
}); });
} }
#[test]
fn multiplier_3() {
#[tokio::test]
async fn multiplier_3() {
run_test(TestCase { run_test(TestCase {
circuit_path: root_path("test-vectors/mycircuit.wasm").as_str(), circuit_path: root_path("test-vectors/mycircuit.wasm").as_str(),
inputs_path: root_path("test-vectors/mycircuit-input3.json").as_str(), inputs_path: root_path("test-vectors/mycircuit-input3.json").as_str(),
@ -410,8 +434,8 @@ mod tests {
}); });
} }
#[test]
fn safe_multipler() {
#[tokio::test]
async fn safe_multipler() {
let witness = let witness =
std::fs::read_to_string(root_path("test-vectors/safe-circuit-witness.json")).unwrap(); std::fs::read_to_string(root_path("test-vectors/safe-circuit-witness.json")).unwrap();
let witness: Vec<String> = serde_json::from_str(&witness).unwrap(); let witness: Vec<String> = serde_json::from_str(&witness).unwrap();
@ -425,8 +449,8 @@ mod tests {
}); });
} }
#[test]
fn smt_verifier() {
#[tokio::test]
async fn smt_verifier() {
let witness = let witness =
std::fs::read_to_string(root_path("test-vectors/smtverifier10-witness.json")).unwrap(); std::fs::read_to_string(root_path("test-vectors/smtverifier10-witness.json")).unwrap();
let witness: Vec<String> = serde_json::from_str(&witness).unwrap(); let witness: Vec<String> = serde_json::from_str(&witness).unwrap();
@ -453,12 +477,16 @@ mod tests {
} }
fn run_test(case: TestCase) { fn run_test(case: TestCase) {
let mut wtns = WitnessCalculator::new(case.circuit_path).unwrap();
let mut store = Store::default();
let mut wtns = WitnessCalculator::new(&mut store, case.circuit_path).unwrap();
assert_eq!( assert_eq!(
wtns.prime.to_str_radix(16), wtns.prime.to_str_radix(16),
"30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001".to_lowercase() "30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001".to_lowercase()
); );
assert_eq!({ wtns.instance.get_n_vars().unwrap() }, case.n_vars);
assert_eq!(
{ wtns.instance.get_n_vars(&mut store).unwrap() },
case.n_vars
);
assert_eq!({ wtns.n64 }, case.n64); assert_eq!({ wtns.n64 }, case.n64);
let inputs_str = std::fs::read_to_string(case.inputs_path).unwrap(); let inputs_str = std::fs::read_to_string(case.inputs_path).unwrap();
@ -483,7 +511,7 @@ mod tests {
}) })
.collect::<HashMap<_, _>>(); .collect::<HashMap<_, _>>();
let res = wtns.calculate_witness(inputs, false).unwrap();
let res = wtns.calculate_witness(&mut store, inputs, false).unwrap();
for (r, w) in res.iter().zip(case.witness) { for (r, w) in res.iter().zip(case.witness) {
assert_eq!(r, &BigInt::from_str(w).unwrap()); assert_eq!(r, &BigInt::from_str(w).unwrap());
} }

+ 8
- 7
src/zkey.rs

@ -375,6 +375,7 @@ mod tests {
use num_bigint::BigUint; use num_bigint::BigUint;
use serde_json::Value; use serde_json::Value;
use std::fs::File; use std::fs::File;
use wasmer::Store;
use crate::circom::CircomReduction; use crate::circom::CircomReduction;
use crate::witness::WitnessCalculator; use crate::witness::WitnessCalculator;
@ -842,8 +843,8 @@ mod tests {
G2Affine::from(G2Projective::new(x, y, z)) G2Affine::from(G2Projective::new(x, y, z))
} }
#[test]
fn verify_proof_with_zkey_with_r1cs() {
#[tokio::test]
async fn verify_proof_with_zkey_with_r1cs() {
let path = "./test-vectors/test.zkey"; let path = "./test-vectors/test.zkey";
let mut file = File::open(path).unwrap(); let mut file = File::open(path).unwrap();
let (params, _matrices) = read_zkey(&mut file).unwrap(); // binfile.proving_key().unwrap(); let (params, _matrices) = read_zkey(&mut file).unwrap(); // binfile.proving_key().unwrap();
@ -871,13 +872,13 @@ mod tests {
assert!(verified); assert!(verified);
} }
#[test]
fn verify_proof_with_zkey_without_r1cs() {
#[tokio::test]
async fn verify_proof_with_zkey_without_r1cs() {
let path = "./test-vectors/test.zkey"; let path = "./test-vectors/test.zkey";
let mut file = File::open(path).unwrap(); let mut file = File::open(path).unwrap();
let (params, matrices) = read_zkey(&mut file).unwrap(); let (params, matrices) = read_zkey(&mut file).unwrap();
let mut wtns = WitnessCalculator::new("./test-vectors/mycircuit.wasm").unwrap();
let mut store = Store::default();
let mut wtns = WitnessCalculator::new(&mut store, "./test-vectors/mycircuit.wasm").unwrap();
let mut inputs: HashMap<String, Vec<num_bigint::BigInt>> = HashMap::new(); let mut inputs: HashMap<String, Vec<num_bigint::BigInt>> = HashMap::new();
let values = inputs.entry("a".to_string()).or_insert_with(Vec::new); let values = inputs.entry("a".to_string()).or_insert_with(Vec::new);
values.push(3.into()); values.push(3.into());
@ -895,7 +896,7 @@ mod tests {
let s = ark_bn254::Fr::rand(rng); let s = ark_bn254::Fr::rand(rng);
let full_assignment = wtns let full_assignment = wtns
.calculate_witness_element::<Bn254, _>(inputs, false)
.calculate_witness_element::<Bn254, _>(&mut store, inputs, false)
.unwrap(); .unwrap();
let proof = Groth16::<Bn254, CircomReduction>::create_proof_with_reduction_and_matrices( let proof = Groth16::<Bn254, CircomReduction>::create_proof_with_reduction_and_matrices(
&params, &params,

+ 8
- 8
tests/groth16.rs

@ -8,8 +8,8 @@ use ark_groth16::Groth16;
type GrothBn = Groth16<Bn254>; type GrothBn = Groth16<Bn254>;
#[test]
fn groth16_proof() -> Result<()> {
#[tokio::test]
async fn groth16_proof() -> Result<()> {
let cfg = CircomConfig::<Bn254>::new( let cfg = CircomConfig::<Bn254>::new(
"./test-vectors/mycircuit.wasm", "./test-vectors/mycircuit.wasm",
"./test-vectors/mycircuit.r1cs", "./test-vectors/mycircuit.r1cs",
@ -39,8 +39,8 @@ fn groth16_proof() -> Result<()> {
Ok(()) Ok(())
} }
#[test]
fn groth16_proof_wrong_input() {
#[tokio::test]
async fn groth16_proof_wrong_input() {
let cfg = CircomConfig::<Bn254>::new( let cfg = CircomConfig::<Bn254>::new(
"./test-vectors/mycircuit.wasm", "./test-vectors/mycircuit.wasm",
"./test-vectors/mycircuit.r1cs", "./test-vectors/mycircuit.r1cs",
@ -60,9 +60,9 @@ fn groth16_proof_wrong_input() {
let _ = builder.build().unwrap_err(); let _ = builder.build().unwrap_err();
} }
#[test]
#[tokio::test]
#[cfg(feature = "circom-2")] #[cfg(feature = "circom-2")]
fn groth16_proof_circom2() -> Result<()> {
async fn groth16_proof_circom2() -> Result<()> {
let cfg = CircomConfig::<Bn254>::new( let cfg = CircomConfig::<Bn254>::new(
"./test-vectors/circom2_multiplier2.wasm", "./test-vectors/circom2_multiplier2.wasm",
"./test-vectors/circom2_multiplier2.r1cs", "./test-vectors/circom2_multiplier2.r1cs",
@ -92,9 +92,9 @@ fn groth16_proof_circom2() -> Result<()> {
Ok(()) Ok(())
} }
#[test]
#[tokio::test]
#[cfg(feature = "circom-2")] #[cfg(feature = "circom-2")]
fn witness_generation_circom2() -> Result<()> {
async fn witness_generation_circom2() -> Result<()> {
let cfg = CircomConfig::<Bn254>::new( let cfg = CircomConfig::<Bn254>::new(
"./test-vectors/circom2_multiplier2.wasm", "./test-vectors/circom2_multiplier2.wasm",
"./test-vectors/circom2_multiplier2.r1cs", "./test-vectors/circom2_multiplier2.r1cs",

Loading…
Cancel
Save