refactor: rename circom_wasm to witness

This commit is contained in:
Georgios Konstantopoulos
2021-07-26 17:32:18 +03:00
parent b64f038283
commit 8ff7f3cd1b
4 changed files with 10 additions and 13 deletions

78
src/witness/circom.rs Normal file
View File

@@ -0,0 +1,78 @@
use color_eyre::Result;
use wasmer::{Function, Instance, Value};
#[derive(Clone, Debug)]
pub struct Wasm(Instance);
impl Wasm {
pub fn new(instance: Instance) -> Self {
Self(instance)
}
pub fn init(&self, sanity_check: bool) -> Result<()> {
let func = self.func("init");
func.call(&[Value::I32(sanity_check as i32)])?;
Ok(())
}
pub fn get_fr_len(&self) -> Result<i32> {
self.get_i32("getFrLen")
}
pub fn get_ptr_raw_prime(&self) -> Result<i32> {
self.get_i32("getPRawPrime")
}
pub fn get_n_vars(&self) -> Result<i32> {
self.get_i32("getNVars")
}
pub fn get_ptr_witness_buffer(&self) -> Result<i32> {
self.get_i32("getWitnessBuffer")
}
pub fn get_ptr_witness(&self, w: i32) -> Result<i32> {
let func = self.func("getPWitness");
let res = func.call(&[w.into()])?;
Ok(res[0].unwrap_i32())
}
pub fn get_signal_offset32(
&self,
p_sig_offset: u32,
component: u32,
hash_msb: u32,
hash_lsb: u32,
) -> Result<()> {
let func = self.func("getSignalOffset32");
func.call(&[
p_sig_offset.into(),
component.into(),
hash_msb.into(),
hash_lsb.into(),
])?;
Ok(())
}
pub fn set_signal(&self, c_idx: i32, component: i32, signal: i32, p_val: i32) -> Result<()> {
let func = self.func("setSignal");
func.call(&[c_idx.into(), component.into(), signal.into(), p_val.into()])?;
Ok(())
}
fn get_i32(&self, name: &str) -> Result<i32> {
let func = self.func(name);
let result = func.call(&[])?;
Ok(result[0].unwrap_i32())
}
fn func(&self, name: &str) -> &Function {
self.0
.exports
.get_function(name)
.unwrap_or_else(|_| panic!("function {} not found", name))
}
}

271
src/witness/memory.rs Normal file
View File

@@ -0,0 +1,271 @@
//! Safe-ish interface for reading and writing specific types to the WASM runtime's memory
use num_traits::ToPrimitive;
use wasmer::{Memory, MemoryView};
// TODO: Decide whether we want Ark here or if it should use a generic BigInt package
use ark_bn254::FrParameters;
use ark_ff::{BigInteger, BigInteger256, FpParameters, FromBytes, Zero};
use num_bigint::{BigInt, BigUint};
use color_eyre::Result;
use std::str::FromStr;
use std::{convert::TryFrom, ops::Deref};
#[derive(Clone, Debug)]
pub struct SafeMemory {
pub memory: Memory,
pub prime: BigInt,
short_max: BigInt,
short_min: BigInt,
r_inv: BigInt,
n32: usize,
}
impl Deref for SafeMemory {
type Target = Memory;
fn deref(&self) -> &Self::Target {
&self.memory
}
}
impl SafeMemory {
/// Creates a new SafeMemory
pub fn new(memory: Memory, n32: usize, prime: BigInt) -> Self {
// TODO: Figure out a better way to calculate these
let short_max = BigInt::from(0x8000_0000u64);
let short_min = BigInt::from_biguint(
num_bigint::Sign::NoSign,
BigUint::try_from(FrParameters::MODULUS).unwrap(),
) - &short_max;
let r_inv = BigInt::from_str(
"9915499612839321149637521777990102151350674507940716049588462388200839649614",
)
.unwrap();
Self {
memory,
prime,
short_max,
short_min,
r_inv,
n32,
}
}
/// Gets an immutable view to the memory in 32 byte chunks
pub fn view(&self) -> MemoryView<u32> {
self.memory.view()
}
/// Returns the next free position in the memory
pub fn free_pos(&self) -> u32 {
self.view()[0].get()
}
/// Sets the next free position in the memory
pub fn set_free_pos(&mut self, ptr: u32) {
self.write_u32(0, ptr);
}
/// Allocates a U32 in memory
pub fn alloc_u32(&mut self) -> u32 {
let p = self.free_pos();
self.set_free_pos(p + 8);
p
}
/// Writes a u32 to the specified memory offset
pub fn write_u32(&mut self, ptr: usize, num: u32) {
let buf = unsafe { self.memory.data_unchecked_mut() };
buf[ptr..ptr + std::mem::size_of::<u32>()].copy_from_slice(&num.to_le_bytes());
}
/// Reads a u32 from the specified memory offset
pub fn read_u32(&self, ptr: usize) -> u32 {
let buf = unsafe { self.memory.data_unchecked() };
let mut bytes = [0; 4];
bytes.copy_from_slice(&buf[ptr..ptr + std::mem::size_of::<u32>()]);
u32::from_le_bytes(bytes)
}
/// Allocates `self.n32 * 4 + 8` bytes in the memory
pub fn alloc_fr(&mut self) -> u32 {
let p = self.free_pos();
self.set_free_pos(p + self.n32 as u32 * 4 + 8);
p
}
/// Writes a Field Element to memory at the specified offset, truncating
/// to smaller u32 types if needed and adjusting the sign via 2s complement
pub fn write_fr(&mut self, ptr: usize, fr: &BigInt) -> Result<()> {
if fr < &self.short_max && fr > &self.short_min {
if fr >= &BigInt::zero() {
self.write_short_positive(ptr, fr)?;
} else {
self.write_short_negative(ptr, fr)?;
}
} else {
self.write_long_normal(ptr, fr)?;
}
Ok(())
}
/// Reads a Field Element from the memory at the specified offset
pub fn read_fr(&self, ptr: usize) -> Result<BigInt> {
let view = self.memory.view::<u8>();
let res = if view[ptr + 4 + 3].get() & 0x80 != 0 {
let mut num = self.read_big(ptr + 8, self.n32)?;
if view[ptr + 4 + 3].get() & 0x40 != 0 {
num = (num * &self.r_inv) % &self.prime
}
num
} else if view[ptr + 3].get() & 0x40 != 0 {
let mut num = self.read_u32(ptr).into();
// handle small negative
num -= BigInt::from(0x100000000i64);
num
} else {
self.read_u32(ptr).into()
};
Ok(res)
}
fn write_short_positive(&mut self, ptr: usize, fr: &BigInt) -> Result<()> {
let num = fr.to_i32().expect("not a short positive");
self.write_u32(ptr, num as u32);
self.write_u32(ptr + 4, 0);
Ok(())
}
fn write_short_negative(&mut self, ptr: usize, fr: &BigInt) -> Result<()> {
// 2s complement
let num = fr - &self.short_min;
let num = num - &self.short_max;
let num = num + BigInt::from(0x0001_0000_0000i64);
let num = num
.to_u32()
.expect("could not cast as u32 (should never happen)");
self.write_u32(ptr, num);
self.write_u32(ptr + 4, 0);
Ok(())
}
fn write_long_normal(&mut self, ptr: usize, fr: &BigInt) -> Result<()> {
self.write_u32(ptr, 0);
self.write_u32(ptr + 4, i32::MIN as u32); // 0x80000000
self.write_big(ptr + 8, fr)?;
Ok(())
}
fn write_big(&self, ptr: usize, num: &BigInt) -> Result<()> {
let buf = unsafe { self.memory.data_unchecked_mut() };
// TODO: How do we handle negative bignums?
let (_, num) = num.clone().into_parts();
let num = BigInteger256::try_from(num).unwrap();
let bytes = num.to_bytes_le();
let len = bytes.len();
buf[ptr..ptr + len].copy_from_slice(&bytes);
Ok(())
}
/// Reads `num_bytes * 32` from the specified memory offset in a Big Integer
pub fn read_big(&self, ptr: usize, num_bytes: usize) -> Result<BigInt> {
let buf = unsafe { self.memory.data_unchecked() };
let buf = &buf[ptr..ptr + num_bytes * 32];
// TODO: Is there a better way to read big integers?
let big = BigInteger256::read(buf).unwrap();
let big = BigUint::try_from(big).unwrap();
Ok(big.into())
}
}
// TODO: Figure out how to read / write numbers > u32
// circom-witness-calculator: Wasm + Memory -> expose BigInts so that they can be consumed by any proof system
// ark-circom:
// 1. can read zkey
// 2. can generate witness from inputs
// 3. can generate proofs
// 4. can serialize proofs in the desired format
#[cfg(test)]
mod tests {
use super::*;
use num_traits::ToPrimitive;
use std::str::FromStr;
use wasmer::{MemoryType, Store};
fn new() -> SafeMemory {
SafeMemory::new(
Memory::new(&Store::default(), MemoryType::new(1, None, false)).unwrap(),
2,
BigInt::from_str(
"21888242871839275222246405745257275088548364400416034343698204186575808495617",
)
.unwrap(),
)
}
#[test]
fn i32_bounds() {
let mem = new();
let i32_max = i32::MAX as i64 + 1;
assert_eq!(mem.short_min.to_i64().unwrap(), -i32_max);
assert_eq!(mem.short_max.to_i64().unwrap(), i32_max);
}
#[test]
fn read_write_32() {
let mut mem = new();
let num = u32::MAX;
let inp = mem.read_u32(0);
assert_eq!(inp, 0);
mem.write_u32(0, num);
let inp = mem.read_u32(0);
assert_eq!(inp, num);
}
#[test]
fn read_write_fr_small_positive() {
read_write_fr(BigInt::from(1_000_000));
}
#[test]
fn read_write_fr_small_negative() {
read_write_fr(BigInt::from(-1_000_000));
}
#[test]
fn read_write_fr_big_positive() {
read_write_fr(BigInt::from(500000000000i64));
}
// TODO: How should this be handled?
#[test]
#[ignore]
fn read_write_fr_big_negative() {
read_write_fr(BigInt::from_str("-500000000000").unwrap())
}
fn read_write_fr(num: BigInt) {
let mut mem = new();
mem.write_fr(0, &num).unwrap();
let res = mem.read_fr(0).unwrap();
assert_eq!(res, num);
}
}

19
src/witness/mod.rs Normal file
View File

@@ -0,0 +1,19 @@
mod witness_calculator;
pub use witness_calculator::WitnessCalculator;
mod memory;
pub(super) use memory::SafeMemory;
mod circom;
pub(super) use circom::Wasm;
use fnv::FnvHasher;
use std::hash::Hasher;
pub(crate) fn fnv(inp: &str) -> (u32, u32) {
let mut hasher = FnvHasher::default();
hasher.write(inp.as_bytes());
let h = hasher.finish();
((h >> 32) as u32, h as u32)
}

View File

@@ -0,0 +1,274 @@
use color_eyre::Result;
use num_bigint::BigInt;
use num_traits::Zero;
use std::cell::Cell;
use wasmer::{imports, Function, Instance, Memory, MemoryType, Module, Store};
use super::{fnv, SafeMemory, Wasm};
#[derive(Clone, Debug)]
pub struct WitnessCalculator {
pub instance: Wasm,
pub memory: SafeMemory,
pub n64: i32,
}
impl WitnessCalculator {
pub fn new(path: impl AsRef<std::path::Path>) -> Result<Self> {
let store = Store::default();
let module = Module::from_file(&store, path)?;
// Set up the memory
let memory = Memory::new(&store, MemoryType::new(2000, None, false)).unwrap();
let import_object = imports! {
"env" => {
"memory" => memory.clone(),
},
// Host function callbacks from the WASM
"runtime" => {
"error" => runtime::error(&store),
"logSetSignal" => runtime::log_signal(&store),
"logGetSignal" => runtime::log_signal(&store),
"logFinishComponent" => runtime::log_component(&store),
"logStartComponent" => runtime::log_component(&store),
"log" => runtime::log_component(&store),
}
};
let instance = Wasm::new(Instance::new(&module, &import_object)?);
let n32 = (instance.get_fr_len()? >> 2) - 2;
let mut memory = SafeMemory::new(memory, n32 as usize, BigInt::zero());
let ptr = instance.get_ptr_raw_prime()?;
let prime = memory.read_big(ptr as usize, n32 as usize)?;
let n64 = ((prime.bits() - 1) / 64 + 1) as i32;
memory.prime = prime;
Ok(WitnessCalculator {
instance,
memory,
n64,
})
}
pub fn calculate_witness<I: IntoIterator<Item = (String, Vec<BigInt>)>>(
&mut self,
inputs: I,
sanity_check: bool,
) -> Result<Vec<BigInt>> {
let old_mem_free_pos = self.memory.free_pos();
self.instance.init(sanity_check)?;
let p_sig_offset = self.memory.alloc_u32();
let p_fr = self.memory.alloc_fr();
// allocate the inputs
for (name, values) in inputs.into_iter() {
let (msb, lsb) = fnv(&name);
self.instance
.get_signal_offset32(p_sig_offset, 0, msb, lsb)?;
let sig_offset = self.memory.read_u32(p_sig_offset as usize) as usize;
for (i, value) in values.into_iter().enumerate() {
self.memory.write_fr(p_fr as usize, &value)?;
self.instance
.set_signal(0, 0, (sig_offset + i) as i32, p_fr as i32)?;
}
}
let mut w = Vec::new();
let n_vars = self.instance.get_n_vars()?;
for i in 0..n_vars {
let ptr = self.instance.get_ptr_witness(i)? as usize;
let el = self.memory.read_fr(ptr)?;
w.push(el);
}
self.memory.set_free_pos(old_mem_free_pos);
Ok(w)
}
pub fn get_witness_buffer(&self) -> Result<Vec<u8>> {
let ptr = self.instance.get_ptr_witness_buffer()? as usize;
let view = self.memory.memory.view::<u8>();
let len = self.instance.get_n_vars()? * self.n64 * 8;
let arr = view[ptr..ptr + len as usize]
.iter()
.map(Cell::get)
.collect::<Vec<_>>();
Ok(arr)
}
}
// callback hooks for debugging
mod runtime {
use super::*;
pub fn error(store: &Store) -> Function {
#[allow(unused)]
#[allow(clippy::many_single_char_names)]
fn func(a: i32, b: i32, c: i32, d: i32, e: i32, f: i32) {}
Function::new_native(&store, func)
}
pub fn log_signal(store: &Store) -> Function {
#[allow(unused)]
fn func(a: i32, b: i32) {}
Function::new_native(&store, func)
}
pub fn log_component(store: &Store) -> Function {
#[allow(unused)]
fn func(a: i32) {}
Function::new_native(&store, func)
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::{collections::HashMap, path::PathBuf};
struct TestCase<'a> {
circuit_path: &'a str,
inputs_path: &'a str,
n_vars: u32,
n64: u32,
witness: &'a [&'a str],
}
pub fn root_path(p: &str) -> String {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push(p);
path.to_string_lossy().to_string()
}
#[test]
fn multiplier_1() {
run_test(TestCase {
circuit_path: root_path("test-vectors/mycircuit.wasm").as_str(),
inputs_path: root_path("test-vectors/mycircuit-input1.json").as_str(),
n_vars: 4,
n64: 4,
witness: &["1", "33", "3", "11"],
});
}
#[test]
fn multiplier_2() {
run_test(TestCase {
circuit_path: root_path("test-vectors/mycircuit.wasm").as_str(),
inputs_path: root_path("test-vectors/mycircuit-input2.json").as_str(),
n_vars: 4,
n64: 4,
witness: &[
"1",
"21888242871839275222246405745257275088548364400416034343698204186575672693159",
"21888242871839275222246405745257275088548364400416034343698204186575796149939",
"11",
],
});
}
#[test]
fn multiplier_3() {
run_test(TestCase {
circuit_path: root_path("test-vectors/mycircuit.wasm").as_str(),
inputs_path: root_path("test-vectors/mycircuit-input3.json").as_str(),
n_vars: 4,
n64: 4,
witness: &[
"1",
"21888242871839275222246405745257275088548364400416034343698204186575808493616",
"10944121435919637611123202872628637544274182200208017171849102093287904246808",
"2",
],
});
}
#[test]
fn safe_multipler() {
let witness =
std::fs::read_to_string(&root_path("test-vectors/safe-circuit-witness.json")).unwrap();
let witness: Vec<String> = serde_json::from_str(&witness).unwrap();
let witness = &witness.iter().map(|x| x.as_ref()).collect::<Vec<_>>();
run_test(TestCase {
circuit_path: root_path("test-vectors/circuit2.wasm").as_str(),
inputs_path: root_path("test-vectors/mycircuit-input1.json").as_str(),
n_vars: 132, // 128 + 4
n64: 4,
witness,
});
}
#[test]
fn smt_verifier() {
let witness =
std::fs::read_to_string(&root_path("test-vectors/smtverifier10-witness.json")).unwrap();
let witness: Vec<String> = serde_json::from_str(&witness).unwrap();
let witness = &witness.iter().map(|x| x.as_ref()).collect::<Vec<_>>();
run_test(TestCase {
circuit_path: root_path("test-vectors/smtverifier10.wasm").as_str(),
inputs_path: root_path("test-vectors/smtverifier10-input.json").as_str(),
n_vars: 4794,
n64: 4,
witness,
});
}
use serde_json::Value;
use std::str::FromStr;
fn value_to_bigint(v: Value) -> BigInt {
match v {
Value::String(inner) => BigInt::from_str(&inner).unwrap(),
Value::Number(inner) => BigInt::from(inner.as_u64().expect("not a u32")),
_ => panic!("unsupported type"),
}
}
fn run_test(case: TestCase) {
let mut wtns = WitnessCalculator::new(case.circuit_path).unwrap();
assert_eq!(
wtns.memory.prime.to_str_radix(16),
"30644E72E131A029B85045B68181585D2833E84879B9709143E1F593F0000001".to_lowercase()
);
assert_eq!(wtns.instance.get_n_vars().unwrap() as u32, case.n_vars);
assert_eq!(wtns.n64 as u32, case.n64);
let inputs_str = std::fs::read_to_string(case.inputs_path).unwrap();
let inputs: std::collections::HashMap<String, serde_json::Value> =
serde_json::from_str(&inputs_str).unwrap();
let inputs = inputs
.iter()
.map(|(key, value)| {
let res = match value {
Value::String(inner) => {
vec![BigInt::from_str(inner).unwrap()]
}
Value::Number(inner) => {
vec![BigInt::from(inner.as_u64().expect("not a u32"))]
}
Value::Array(inner) => inner.iter().cloned().map(value_to_bigint).collect(),
_ => panic!(),
};
(key.clone(), res)
})
.collect::<HashMap<_, _>>();
let res = wtns.calculate_witness(inputs, false).unwrap();
for (r, w) in res.iter().zip(case.witness) {
assert_eq!(r, &BigInt::from_str(w).unwrap());
}
}
}