mirror of
https://github.com/arnaucube/plonky2-u32.git
synced 2026-02-10 13:16:46 +01:00
first commit
This commit is contained in:
303
src/gadgets/arithmetic_u32.rs
Normal file
303
src/gadgets/arithmetic_u32.rs
Normal file
@@ -0,0 +1,303 @@
|
||||
use alloc::vec;
|
||||
use alloc::vec::Vec;
|
||||
use core::marker::PhantomData;
|
||||
|
||||
use plonky2::field::extension::Extendable;
|
||||
use plonky2::hash::hash_types::RichField;
|
||||
use plonky2::iop::generator::{GeneratedValues, SimpleGenerator};
|
||||
use plonky2::iop::target::Target;
|
||||
use plonky2::iop::witness::{PartitionWitness, Witness};
|
||||
use plonky2::plonk::circuit_builder::CircuitBuilder;
|
||||
|
||||
use crate::gates::add_many_u32::U32AddManyGate;
|
||||
use crate::gates::arithmetic_u32::U32ArithmeticGate;
|
||||
use crate::gates::subtraction_u32::U32SubtractionGate;
|
||||
use crate::witness::GeneratedValuesU32;
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct U32Target(pub Target);
|
||||
|
||||
pub trait CircuitBuilderU32<F: RichField + Extendable<D>, const D: usize> {
|
||||
fn add_virtual_u32_target(&mut self) -> U32Target;
|
||||
|
||||
fn add_virtual_u32_targets(&mut self, n: usize) -> Vec<U32Target>;
|
||||
|
||||
/// Returns a U32Target for the value `c`, which is assumed to be at most 32 bits.
|
||||
fn constant_u32(&mut self, c: u32) -> U32Target;
|
||||
|
||||
fn zero_u32(&mut self) -> U32Target;
|
||||
|
||||
fn one_u32(&mut self) -> U32Target;
|
||||
|
||||
fn connect_u32(&mut self, x: U32Target, y: U32Target);
|
||||
|
||||
fn assert_zero_u32(&mut self, x: U32Target);
|
||||
|
||||
/// Checks for special cases where the value of
|
||||
/// `x * y + z`
|
||||
/// can be determined without adding a `U32ArithmeticGate`.
|
||||
fn arithmetic_u32_special_cases(
|
||||
&mut self,
|
||||
x: U32Target,
|
||||
y: U32Target,
|
||||
z: U32Target,
|
||||
) -> Option<(U32Target, U32Target)>;
|
||||
|
||||
// Returns x * y + z.
|
||||
fn mul_add_u32(&mut self, x: U32Target, y: U32Target, z: U32Target) -> (U32Target, U32Target);
|
||||
|
||||
fn add_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target);
|
||||
|
||||
fn add_many_u32(&mut self, to_add: &[U32Target]) -> (U32Target, U32Target);
|
||||
|
||||
fn add_u32s_with_carry(
|
||||
&mut self,
|
||||
to_add: &[U32Target],
|
||||
carry: U32Target,
|
||||
) -> (U32Target, U32Target);
|
||||
|
||||
fn mul_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target);
|
||||
|
||||
// Returns x - y - borrow, as a pair (result, borrow), where borrow is 0 or 1 depending on whether borrowing from the next digit is required (iff y + borrow > x).
|
||||
fn sub_u32(&mut self, x: U32Target, y: U32Target, borrow: U32Target) -> (U32Target, U32Target);
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilderU32<F, D>
|
||||
for CircuitBuilder<F, D>
|
||||
{
|
||||
fn add_virtual_u32_target(&mut self) -> U32Target {
|
||||
U32Target(self.add_virtual_target())
|
||||
}
|
||||
|
||||
fn add_virtual_u32_targets(&mut self, n: usize) -> Vec<U32Target> {
|
||||
self.add_virtual_targets(n)
|
||||
.into_iter()
|
||||
.map(U32Target)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Returns a U32Target for the value `c`, which is assumed to be at most 32 bits.
|
||||
fn constant_u32(&mut self, c: u32) -> U32Target {
|
||||
U32Target(self.constant(F::from_canonical_u32(c)))
|
||||
}
|
||||
|
||||
fn zero_u32(&mut self) -> U32Target {
|
||||
U32Target(self.zero())
|
||||
}
|
||||
|
||||
fn one_u32(&mut self) -> U32Target {
|
||||
U32Target(self.one())
|
||||
}
|
||||
|
||||
fn connect_u32(&mut self, x: U32Target, y: U32Target) {
|
||||
self.connect(x.0, y.0)
|
||||
}
|
||||
|
||||
fn assert_zero_u32(&mut self, x: U32Target) {
|
||||
self.assert_zero(x.0)
|
||||
}
|
||||
|
||||
/// Checks for special cases where the value of
|
||||
/// `x * y + z`
|
||||
/// can be determined without adding a `U32ArithmeticGate`.
|
||||
fn arithmetic_u32_special_cases(
|
||||
&mut self,
|
||||
x: U32Target,
|
||||
y: U32Target,
|
||||
z: U32Target,
|
||||
) -> Option<(U32Target, U32Target)> {
|
||||
let x_const = self.target_as_constant(x.0);
|
||||
let y_const = self.target_as_constant(y.0);
|
||||
let z_const = self.target_as_constant(z.0);
|
||||
|
||||
// If both terms are constant, return their (constant) sum.
|
||||
let first_term_const = if let (Some(xx), Some(yy)) = (x_const, y_const) {
|
||||
Some(xx * yy)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let (Some(a), Some(b)) = (first_term_const, z_const) {
|
||||
let sum = (a + b).to_canonical_u64();
|
||||
let (low, high) = (sum as u32, (sum >> 32) as u32);
|
||||
return Some((self.constant_u32(low), self.constant_u32(high)));
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
// Returns x * y + z.
|
||||
fn mul_add_u32(&mut self, x: U32Target, y: U32Target, z: U32Target) -> (U32Target, U32Target) {
|
||||
if let Some(result) = self.arithmetic_u32_special_cases(x, y, z) {
|
||||
return result;
|
||||
}
|
||||
|
||||
let gate = U32ArithmeticGate::<F, D>::new_from_config(&self.config);
|
||||
let (row, copy) = self.find_slot(gate, &[], &[]);
|
||||
|
||||
self.connect(Target::wire(row, gate.wire_ith_multiplicand_0(copy)), x.0);
|
||||
self.connect(Target::wire(row, gate.wire_ith_multiplicand_1(copy)), y.0);
|
||||
self.connect(Target::wire(row, gate.wire_ith_addend(copy)), z.0);
|
||||
|
||||
let output_low = U32Target(Target::wire(row, gate.wire_ith_output_low_half(copy)));
|
||||
let output_high = U32Target(Target::wire(row, gate.wire_ith_output_high_half(copy)));
|
||||
|
||||
(output_low, output_high)
|
||||
}
|
||||
|
||||
fn add_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) {
|
||||
let one = self.one_u32();
|
||||
self.mul_add_u32(a, one, b)
|
||||
}
|
||||
|
||||
fn add_many_u32(&mut self, to_add: &[U32Target]) -> (U32Target, U32Target) {
|
||||
match to_add.len() {
|
||||
0 => (self.zero_u32(), self.zero_u32()),
|
||||
1 => (to_add[0], self.zero_u32()),
|
||||
2 => self.add_u32(to_add[0], to_add[1]),
|
||||
_ => {
|
||||
let num_addends = to_add.len();
|
||||
let gate = U32AddManyGate::<F, D>::new_from_config(&self.config, num_addends);
|
||||
let (row, copy) =
|
||||
self.find_slot(gate, &[F::from_canonical_usize(num_addends)], &[]);
|
||||
|
||||
for j in 0..num_addends {
|
||||
self.connect(
|
||||
Target::wire(row, gate.wire_ith_op_jth_addend(copy, j)),
|
||||
to_add[j].0,
|
||||
);
|
||||
}
|
||||
let zero = self.zero();
|
||||
self.connect(Target::wire(row, gate.wire_ith_carry(copy)), zero);
|
||||
|
||||
let output_low = U32Target(Target::wire(row, gate.wire_ith_output_result(copy)));
|
||||
let output_high = U32Target(Target::wire(row, gate.wire_ith_output_carry(copy)));
|
||||
|
||||
(output_low, output_high)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn add_u32s_with_carry(
|
||||
&mut self,
|
||||
to_add: &[U32Target],
|
||||
carry: U32Target,
|
||||
) -> (U32Target, U32Target) {
|
||||
if to_add.len() == 1 {
|
||||
return self.add_u32(to_add[0], carry);
|
||||
}
|
||||
|
||||
let num_addends = to_add.len();
|
||||
|
||||
let gate = U32AddManyGate::<F, D>::new_from_config(&self.config, num_addends);
|
||||
let (row, copy) = self.find_slot(gate, &[F::from_canonical_usize(num_addends)], &[]);
|
||||
|
||||
for j in 0..num_addends {
|
||||
self.connect(
|
||||
Target::wire(row, gate.wire_ith_op_jth_addend(copy, j)),
|
||||
to_add[j].0,
|
||||
);
|
||||
}
|
||||
self.connect(Target::wire(row, gate.wire_ith_carry(copy)), carry.0);
|
||||
|
||||
let output = U32Target(Target::wire(row, gate.wire_ith_output_result(copy)));
|
||||
let output_carry = U32Target(Target::wire(row, gate.wire_ith_output_carry(copy)));
|
||||
|
||||
(output, output_carry)
|
||||
}
|
||||
|
||||
fn mul_u32(&mut self, a: U32Target, b: U32Target) -> (U32Target, U32Target) {
|
||||
let zero = self.zero_u32();
|
||||
self.mul_add_u32(a, b, zero)
|
||||
}
|
||||
|
||||
// Returns x - y - borrow, as a pair (result, borrow), where borrow is 0 or 1 depending on whether borrowing from the next digit is required (iff y + borrow > x).
|
||||
fn sub_u32(&mut self, x: U32Target, y: U32Target, borrow: U32Target) -> (U32Target, U32Target) {
|
||||
let gate = U32SubtractionGate::<F, D>::new_from_config(&self.config);
|
||||
let (row, copy) = self.find_slot(gate, &[], &[]);
|
||||
|
||||
self.connect(Target::wire(row, gate.wire_ith_input_x(copy)), x.0);
|
||||
self.connect(Target::wire(row, gate.wire_ith_input_y(copy)), y.0);
|
||||
self.connect(
|
||||
Target::wire(row, gate.wire_ith_input_borrow(copy)),
|
||||
borrow.0,
|
||||
);
|
||||
|
||||
let output_result = U32Target(Target::wire(row, gate.wire_ith_output_result(copy)));
|
||||
let output_borrow = U32Target(Target::wire(row, gate.wire_ith_output_borrow(copy)));
|
||||
|
||||
(output_result, output_borrow)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct SplitToU32Generator<F: RichField + Extendable<D>, const D: usize> {
|
||||
x: Target,
|
||||
low: U32Target,
|
||||
high: U32Target,
|
||||
_phantom: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
|
||||
for SplitToU32Generator<F, D>
|
||||
{
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
vec![self.x]
|
||||
}
|
||||
|
||||
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
|
||||
let x = witness.get_target(self.x);
|
||||
let x_u64 = x.to_canonical_u64();
|
||||
let low = x_u64 as u32;
|
||||
let high = (x_u64 >> 32) as u32;
|
||||
|
||||
out_buffer.set_u32_target(self.low, low);
|
||||
out_buffer.set_u32_target(self.high, high);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
use plonky2::iop::witness::PartialWitness;
|
||||
use plonky2::plonk::circuit_data::CircuitConfig;
|
||||
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
|
||||
use rand::rngs::OsRng;
|
||||
use rand::Rng;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
pub fn test_add_many_u32s() -> Result<()> {
|
||||
const D: usize = 2;
|
||||
type C = PoseidonGoldilocksConfig;
|
||||
type F = <C as GenericConfig<D>>::F;
|
||||
|
||||
const NUM_ADDENDS: usize = 15;
|
||||
|
||||
let config = CircuitConfig::standard_recursion_config();
|
||||
|
||||
let pw = PartialWitness::new();
|
||||
let mut builder = CircuitBuilder::<F, D>::new(config);
|
||||
|
||||
let mut rng = OsRng;
|
||||
let mut to_add = Vec::new();
|
||||
let mut sum = 0u64;
|
||||
for _ in 0..NUM_ADDENDS {
|
||||
let x: u32 = rng.gen();
|
||||
sum += x as u64;
|
||||
to_add.push(builder.constant_u32(x));
|
||||
}
|
||||
let carry = builder.zero_u32();
|
||||
let (result_low, result_high) = builder.add_u32s_with_carry(&to_add, carry);
|
||||
let expected_low = builder.constant_u32((sum % (1 << 32)) as u32);
|
||||
let expected_high = builder.constant_u32((sum >> 32) as u32);
|
||||
|
||||
builder.connect_u32(result_low, expected_low);
|
||||
builder.connect_u32(result_high, expected_high);
|
||||
|
||||
let data = builder.build::<C>();
|
||||
let proof = data.prove(pw).unwrap();
|
||||
data.verify(proof)
|
||||
}
|
||||
}
|
||||
3
src/gadgets/mod.rs
Normal file
3
src/gadgets/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod arithmetic_u32;
|
||||
pub mod multiple_comparison;
|
||||
pub mod range_check;
|
||||
152
src/gadgets/multiple_comparison.rs
Normal file
152
src/gadgets/multiple_comparison.rs
Normal file
@@ -0,0 +1,152 @@
|
||||
use alloc::vec;
|
||||
use alloc::vec::Vec;
|
||||
|
||||
use plonky2::field::extension::Extendable;
|
||||
use plonky2::hash::hash_types::RichField;
|
||||
use plonky2::iop::target::{BoolTarget, Target};
|
||||
use plonky2::plonk::circuit_builder::CircuitBuilder;
|
||||
use plonky2::util::ceil_div_usize;
|
||||
|
||||
use crate::gadgets::arithmetic_u32::U32Target;
|
||||
use crate::gates::comparison::ComparisonGate;
|
||||
|
||||
/// Returns true if a is less than or equal to b, considered as base-`2^num_bits` limbs of a large value.
|
||||
/// This range-checks its inputs.
|
||||
pub fn list_le_circuit<F: RichField + Extendable<D>, const D: usize>(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
a: Vec<Target>,
|
||||
b: Vec<Target>,
|
||||
num_bits: usize,
|
||||
) -> BoolTarget {
|
||||
assert_eq!(
|
||||
a.len(),
|
||||
b.len(),
|
||||
"Comparison must be between same number of inputs and outputs"
|
||||
);
|
||||
let n = a.len();
|
||||
|
||||
let chunk_bits = 2;
|
||||
let num_chunks = ceil_div_usize(num_bits, chunk_bits);
|
||||
|
||||
let one = builder.one();
|
||||
let mut result = one;
|
||||
for i in 0..n {
|
||||
let a_le_b_gate = ComparisonGate::new(num_bits, num_chunks);
|
||||
let a_le_b_row = builder.add_gate(a_le_b_gate.clone(), vec![]);
|
||||
builder.connect(
|
||||
Target::wire(a_le_b_row, a_le_b_gate.wire_first_input()),
|
||||
a[i],
|
||||
);
|
||||
builder.connect(
|
||||
Target::wire(a_le_b_row, a_le_b_gate.wire_second_input()),
|
||||
b[i],
|
||||
);
|
||||
let a_le_b_result = Target::wire(a_le_b_row, a_le_b_gate.wire_result_bool());
|
||||
|
||||
let b_le_a_gate = ComparisonGate::new(num_bits, num_chunks);
|
||||
let b_le_a_row = builder.add_gate(b_le_a_gate.clone(), vec![]);
|
||||
builder.connect(
|
||||
Target::wire(b_le_a_row, b_le_a_gate.wire_first_input()),
|
||||
b[i],
|
||||
);
|
||||
builder.connect(
|
||||
Target::wire(b_le_a_row, b_le_a_gate.wire_second_input()),
|
||||
a[i],
|
||||
);
|
||||
let b_le_a_result = Target::wire(b_le_a_row, b_le_a_gate.wire_result_bool());
|
||||
|
||||
let these_limbs_equal = builder.mul(a_le_b_result, b_le_a_result);
|
||||
let these_limbs_less_than = builder.sub(one, b_le_a_result);
|
||||
result = builder.mul_add(these_limbs_equal, result, these_limbs_less_than);
|
||||
}
|
||||
|
||||
// `result` being boolean is an invariant, maintained because its new value is always
|
||||
// `x * result + y`, where `x` and `y` are booleans that are not simultaneously true.
|
||||
BoolTarget::new_unsafe(result)
|
||||
}
|
||||
|
||||
/// Helper function for comparing, specifically, lists of `U32Target`s.
|
||||
pub fn list_le_u32_circuit<F: RichField + Extendable<D>, const D: usize>(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
a: Vec<U32Target>,
|
||||
b: Vec<U32Target>,
|
||||
) -> BoolTarget {
|
||||
let a_targets: Vec<Target> = a.iter().map(|&t| t.0).collect();
|
||||
let b_targets: Vec<Target> = b.iter().map(|&t| t.0).collect();
|
||||
|
||||
list_le_circuit(builder, a_targets, b_targets, 32)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
use num::BigUint;
|
||||
use plonky2::field::types::Field;
|
||||
use plonky2::iop::witness::PartialWitness;
|
||||
use plonky2::plonk::circuit_data::CircuitConfig;
|
||||
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
|
||||
use rand::rngs::OsRng;
|
||||
use rand::Rng;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn test_list_le(size: usize, num_bits: usize) -> Result<()> {
|
||||
const D: usize = 2;
|
||||
type C = PoseidonGoldilocksConfig;
|
||||
type F = <C as GenericConfig<D>>::F;
|
||||
let config = CircuitConfig::standard_recursion_config();
|
||||
let pw = PartialWitness::new();
|
||||
let mut builder = CircuitBuilder::<F, D>::new(config);
|
||||
|
||||
let mut rng = OsRng;
|
||||
|
||||
let lst1: Vec<u64> = (0..size)
|
||||
.map(|_| rng.gen_range(0..(1 << num_bits)))
|
||||
.collect();
|
||||
let lst2: Vec<u64> = (0..size)
|
||||
.map(|_| rng.gen_range(0..(1 << num_bits)))
|
||||
.collect();
|
||||
|
||||
let a_biguint = BigUint::from_slice(
|
||||
&lst1
|
||||
.iter()
|
||||
.flat_map(|&x| [x as u32, (x >> 32) as u32])
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
let b_biguint = BigUint::from_slice(
|
||||
&lst2
|
||||
.iter()
|
||||
.flat_map(|&x| [x as u32, (x >> 32) as u32])
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
let a = lst1
|
||||
.iter()
|
||||
.map(|&x| builder.constant(F::from_canonical_u64(x)))
|
||||
.collect();
|
||||
let b = lst2
|
||||
.iter()
|
||||
.map(|&x| builder.constant(F::from_canonical_u64(x)))
|
||||
.collect();
|
||||
|
||||
let result = list_le_circuit(&mut builder, a, b, num_bits);
|
||||
|
||||
let expected_result = builder.constant_bool(a_biguint <= b_biguint);
|
||||
builder.connect(result.target, expected_result.target);
|
||||
|
||||
let data = builder.build::<C>();
|
||||
let proof = data.prove(pw).unwrap();
|
||||
data.verify(proof)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_comparison() -> Result<()> {
|
||||
for size in [1, 3, 6] {
|
||||
for num_bits in [20, 32, 40, 44] {
|
||||
test_list_le(size, num_bits).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
23
src/gadgets/range_check.rs
Normal file
23
src/gadgets/range_check.rs
Normal file
@@ -0,0 +1,23 @@
|
||||
use alloc::vec;
|
||||
use alloc::vec::Vec;
|
||||
|
||||
use plonky2::field::extension::Extendable;
|
||||
use plonky2::hash::hash_types::RichField;
|
||||
use plonky2::iop::target::Target;
|
||||
use plonky2::plonk::circuit_builder::CircuitBuilder;
|
||||
|
||||
use crate::gadgets::arithmetic_u32::U32Target;
|
||||
use crate::gates::range_check_u32::U32RangeCheckGate;
|
||||
|
||||
pub fn range_check_u32_circuit<F: RichField + Extendable<D>, const D: usize>(
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
vals: Vec<U32Target>,
|
||||
) {
|
||||
let num_input_limbs = vals.len();
|
||||
let gate = U32RangeCheckGate::<F, D>::new(num_input_limbs);
|
||||
let row = builder.add_gate(gate, vec![]);
|
||||
|
||||
for i in 0..num_input_limbs {
|
||||
builder.connect(Target::wire(row, gate.wire_ith_input_limb(i)), vals[i].0);
|
||||
}
|
||||
}
|
||||
456
src/gates/add_many_u32.rs
Normal file
456
src/gates/add_many_u32.rs
Normal file
@@ -0,0 +1,456 @@
|
||||
use alloc::boxed::Box;
|
||||
use alloc::format;
|
||||
use alloc::string::String;
|
||||
use alloc::vec::Vec;
|
||||
use core::marker::PhantomData;
|
||||
|
||||
use itertools::unfold;
|
||||
use plonky2::field::extension::Extendable;
|
||||
use plonky2::field::types::Field;
|
||||
use plonky2::gates::gate::Gate;
|
||||
use plonky2::gates::util::StridedConstraintConsumer;
|
||||
use plonky2::hash::hash_types::RichField;
|
||||
use plonky2::iop::ext_target::ExtensionTarget;
|
||||
use plonky2::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use plonky2::iop::target::Target;
|
||||
use plonky2::iop::wire::Wire;
|
||||
use plonky2::iop::witness::{PartitionWitness, Witness, WitnessWrite};
|
||||
use plonky2::plonk::circuit_builder::CircuitBuilder;
|
||||
use plonky2::plonk::circuit_data::CircuitConfig;
|
||||
use plonky2::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
|
||||
use plonky2::util::ceil_div_usize;
|
||||
|
||||
const LOG2_MAX_NUM_ADDENDS: usize = 4;
|
||||
const MAX_NUM_ADDENDS: usize = 16;
|
||||
|
||||
/// A gate to perform addition on `num_addends` different 32-bit values, plus a small carry
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct U32AddManyGate<F: RichField + Extendable<D>, const D: usize> {
|
||||
pub num_addends: usize,
|
||||
pub num_ops: usize,
|
||||
_phantom: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> U32AddManyGate<F, D> {
|
||||
pub fn new_from_config(config: &CircuitConfig, num_addends: usize) -> Self {
|
||||
Self {
|
||||
num_addends,
|
||||
num_ops: Self::num_ops(num_addends, config),
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn num_ops(num_addends: usize, config: &CircuitConfig) -> usize {
|
||||
debug_assert!(num_addends <= MAX_NUM_ADDENDS);
|
||||
let wires_per_op = (num_addends + 3) + Self::num_limbs();
|
||||
let routed_wires_per_op = num_addends + 3;
|
||||
(config.num_wires / wires_per_op).min(config.num_routed_wires / routed_wires_per_op)
|
||||
}
|
||||
|
||||
pub fn wire_ith_op_jth_addend(&self, i: usize, j: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
debug_assert!(j < self.num_addends);
|
||||
(self.num_addends + 3) * i + j
|
||||
}
|
||||
pub fn wire_ith_carry(&self, i: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
(self.num_addends + 3) * i + self.num_addends
|
||||
}
|
||||
|
||||
pub fn wire_ith_output_result(&self, i: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
(self.num_addends + 3) * i + self.num_addends + 1
|
||||
}
|
||||
pub fn wire_ith_output_carry(&self, i: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
(self.num_addends + 3) * i + self.num_addends + 2
|
||||
}
|
||||
|
||||
pub fn limb_bits() -> usize {
|
||||
2
|
||||
}
|
||||
pub fn num_result_limbs() -> usize {
|
||||
ceil_div_usize(32, Self::limb_bits())
|
||||
}
|
||||
pub fn num_carry_limbs() -> usize {
|
||||
ceil_div_usize(LOG2_MAX_NUM_ADDENDS, Self::limb_bits())
|
||||
}
|
||||
pub fn num_limbs() -> usize {
|
||||
Self::num_result_limbs() + Self::num_carry_limbs()
|
||||
}
|
||||
|
||||
pub fn wire_ith_output_jth_limb(&self, i: usize, j: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
debug_assert!(j < Self::num_limbs());
|
||||
(self.num_addends + 3) * self.num_ops + Self::num_limbs() * i + j
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for U32AddManyGate<F, D> {
|
||||
fn id(&self) -> String {
|
||||
format!("{self:?}")
|
||||
}
|
||||
|
||||
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
for i in 0..self.num_ops {
|
||||
let addends: Vec<F::Extension> = (0..self.num_addends)
|
||||
.map(|j| vars.local_wires[self.wire_ith_op_jth_addend(i, j)])
|
||||
.collect();
|
||||
let carry = vars.local_wires[self.wire_ith_carry(i)];
|
||||
|
||||
let computed_output = addends.iter().fold(F::Extension::ZERO, |x, &y| x + y) + carry;
|
||||
|
||||
let output_result = vars.local_wires[self.wire_ith_output_result(i)];
|
||||
let output_carry = vars.local_wires[self.wire_ith_output_carry(i)];
|
||||
|
||||
let base = F::Extension::from_canonical_u64(1 << 32u64);
|
||||
let combined_output = output_carry * base + output_result;
|
||||
|
||||
constraints.push(combined_output - computed_output);
|
||||
|
||||
let mut combined_result_limbs = F::Extension::ZERO;
|
||||
let mut combined_carry_limbs = F::Extension::ZERO;
|
||||
let base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits());
|
||||
for j in (0..Self::num_limbs()).rev() {
|
||||
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
|
||||
let max_limb = 1 << Self::limb_bits();
|
||||
let product = (0..max_limb)
|
||||
.map(|x| this_limb - F::Extension::from_canonical_usize(x))
|
||||
.product();
|
||||
constraints.push(product);
|
||||
|
||||
if j < Self::num_result_limbs() {
|
||||
combined_result_limbs = base * combined_result_limbs + this_limb;
|
||||
} else {
|
||||
combined_carry_limbs = base * combined_carry_limbs + this_limb;
|
||||
}
|
||||
}
|
||||
constraints.push(combined_result_limbs - output_result);
|
||||
constraints.push(combined_carry_limbs - output_carry);
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base_one(
|
||||
&self,
|
||||
vars: EvaluationVarsBase<F>,
|
||||
mut yield_constr: StridedConstraintConsumer<F>,
|
||||
) {
|
||||
for i in 0..self.num_ops {
|
||||
let addends: Vec<F> = (0..self.num_addends)
|
||||
.map(|j| vars.local_wires[self.wire_ith_op_jth_addend(i, j)])
|
||||
.collect();
|
||||
let carry = vars.local_wires[self.wire_ith_carry(i)];
|
||||
|
||||
let computed_output = addends.iter().fold(F::ZERO, |x, &y| x + y) + carry;
|
||||
|
||||
let output_result = vars.local_wires[self.wire_ith_output_result(i)];
|
||||
let output_carry = vars.local_wires[self.wire_ith_output_carry(i)];
|
||||
|
||||
let base = F::from_canonical_u64(1 << 32u64);
|
||||
let combined_output = output_carry * base + output_result;
|
||||
|
||||
yield_constr.one(combined_output - computed_output);
|
||||
|
||||
let mut combined_result_limbs = F::ZERO;
|
||||
let mut combined_carry_limbs = F::ZERO;
|
||||
let base = F::from_canonical_u64(1u64 << Self::limb_bits());
|
||||
for j in (0..Self::num_limbs()).rev() {
|
||||
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
|
||||
let max_limb = 1 << Self::limb_bits();
|
||||
let product = (0..max_limb)
|
||||
.map(|x| this_limb - F::from_canonical_usize(x))
|
||||
.product();
|
||||
yield_constr.one(product);
|
||||
|
||||
if j < Self::num_result_limbs() {
|
||||
combined_result_limbs = base * combined_result_limbs + this_limb;
|
||||
} else {
|
||||
combined_carry_limbs = base * combined_carry_limbs + this_limb;
|
||||
}
|
||||
}
|
||||
yield_constr.one(combined_result_limbs - output_result);
|
||||
yield_constr.one(combined_carry_limbs - output_carry);
|
||||
}
|
||||
}
|
||||
|
||||
fn eval_unfiltered_circuit(
|
||||
&self,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
vars: EvaluationTargets<D>,
|
||||
) -> Vec<ExtensionTarget<D>> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
|
||||
for i in 0..self.num_ops {
|
||||
let addends: Vec<ExtensionTarget<D>> = (0..self.num_addends)
|
||||
.map(|j| vars.local_wires[self.wire_ith_op_jth_addend(i, j)])
|
||||
.collect();
|
||||
let carry = vars.local_wires[self.wire_ith_carry(i)];
|
||||
|
||||
let mut computed_output = carry;
|
||||
for addend in addends {
|
||||
computed_output = builder.add_extension(computed_output, addend);
|
||||
}
|
||||
|
||||
let output_result = vars.local_wires[self.wire_ith_output_result(i)];
|
||||
let output_carry = vars.local_wires[self.wire_ith_output_carry(i)];
|
||||
|
||||
let base: F::Extension = F::from_canonical_u64(1 << 32u64).into();
|
||||
let base_target = builder.constant_extension(base);
|
||||
let combined_output =
|
||||
builder.mul_add_extension(output_carry, base_target, output_result);
|
||||
|
||||
constraints.push(builder.sub_extension(combined_output, computed_output));
|
||||
|
||||
let mut combined_result_limbs = builder.zero_extension();
|
||||
let mut combined_carry_limbs = builder.zero_extension();
|
||||
let base = builder
|
||||
.constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits()));
|
||||
for j in (0..Self::num_limbs()).rev() {
|
||||
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
|
||||
let max_limb = 1 << Self::limb_bits();
|
||||
|
||||
let mut product = builder.one_extension();
|
||||
for x in 0..max_limb {
|
||||
let x_target =
|
||||
builder.constant_extension(F::Extension::from_canonical_usize(x));
|
||||
let diff = builder.sub_extension(this_limb, x_target);
|
||||
product = builder.mul_extension(product, diff);
|
||||
}
|
||||
constraints.push(product);
|
||||
|
||||
if j < Self::num_result_limbs() {
|
||||
combined_result_limbs =
|
||||
builder.mul_add_extension(base, combined_result_limbs, this_limb);
|
||||
} else {
|
||||
combined_carry_limbs =
|
||||
builder.mul_add_extension(base, combined_carry_limbs, this_limb);
|
||||
}
|
||||
}
|
||||
constraints.push(builder.sub_extension(combined_result_limbs, output_result));
|
||||
constraints.push(builder.sub_extension(combined_carry_limbs, output_carry));
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn generators(&self, row: usize, _local_constants: &[F]) -> Vec<Box<dyn WitnessGenerator<F>>> {
|
||||
(0..self.num_ops)
|
||||
.map(|i| {
|
||||
let g: Box<dyn WitnessGenerator<F>> = Box::new(
|
||||
U32AddManyGenerator {
|
||||
gate: *self,
|
||||
row,
|
||||
i,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
.adapter(),
|
||||
);
|
||||
g
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn num_wires(&self) -> usize {
|
||||
(self.num_addends + 3) * self.num_ops + Self::num_limbs() * self.num_ops
|
||||
}
|
||||
|
||||
fn num_constants(&self) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
fn degree(&self) -> usize {
|
||||
1 << Self::limb_bits()
|
||||
}
|
||||
|
||||
fn num_constraints(&self) -> usize {
|
||||
self.num_ops * (3 + Self::num_limbs())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct U32AddManyGenerator<F: RichField + Extendable<D>, const D: usize> {
|
||||
gate: U32AddManyGate<F, D>,
|
||||
row: usize,
|
||||
i: usize,
|
||||
_phantom: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
|
||||
for U32AddManyGenerator<F, D>
|
||||
{
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
let local_target = |column| Target::wire(self.row, column);
|
||||
|
||||
(0..self.gate.num_addends)
|
||||
.map(|j| local_target(self.gate.wire_ith_op_jth_addend(self.i, j)))
|
||||
.chain([local_target(self.gate.wire_ith_carry(self.i))])
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
|
||||
let local_wire = |column| Wire {
|
||||
row: self.row,
|
||||
column,
|
||||
};
|
||||
|
||||
let get_local_wire = |column| witness.get_wire(local_wire(column));
|
||||
|
||||
let addends: Vec<_> = (0..self.gate.num_addends)
|
||||
.map(|j| get_local_wire(self.gate.wire_ith_op_jth_addend(self.i, j)))
|
||||
.collect();
|
||||
let carry = get_local_wire(self.gate.wire_ith_carry(self.i));
|
||||
|
||||
let output = addends.iter().fold(F::ZERO, |x, &y| x + y) + carry;
|
||||
let output_u64 = output.to_canonical_u64();
|
||||
|
||||
let output_carry_u64 = output_u64 >> 32;
|
||||
let output_result_u64 = output_u64 & ((1 << 32) - 1);
|
||||
|
||||
let output_carry = F::from_canonical_u64(output_carry_u64);
|
||||
let output_result = F::from_canonical_u64(output_result_u64);
|
||||
|
||||
let output_carry_wire = local_wire(self.gate.wire_ith_output_carry(self.i));
|
||||
let output_result_wire = local_wire(self.gate.wire_ith_output_result(self.i));
|
||||
|
||||
out_buffer.set_wire(output_carry_wire, output_carry);
|
||||
out_buffer.set_wire(output_result_wire, output_result);
|
||||
|
||||
let num_result_limbs = U32AddManyGate::<F, D>::num_result_limbs();
|
||||
let num_carry_limbs = U32AddManyGate::<F, D>::num_carry_limbs();
|
||||
let limb_base = 1 << U32AddManyGate::<F, D>::limb_bits();
|
||||
|
||||
let split_to_limbs = |mut val, num| {
|
||||
unfold((), move |_| {
|
||||
let ret = val % limb_base;
|
||||
val /= limb_base;
|
||||
Some(ret)
|
||||
})
|
||||
.take(num)
|
||||
.map(F::from_canonical_u64)
|
||||
};
|
||||
|
||||
let result_limbs = split_to_limbs(output_result_u64, num_result_limbs);
|
||||
let carry_limbs = split_to_limbs(output_carry_u64, num_carry_limbs);
|
||||
|
||||
for (j, limb) in result_limbs.chain(carry_limbs).enumerate() {
|
||||
let wire = local_wire(self.gate.wire_ith_output_jth_limb(self.i, j));
|
||||
out_buffer.set_wire(wire, limb);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
use plonky2::field::extension::quartic::QuarticExtension;
|
||||
use plonky2::field::goldilocks_field::GoldilocksField;
|
||||
use plonky2::field::types::Sample;
|
||||
use plonky2::gates::gate_testing::{test_eval_fns, test_low_degree};
|
||||
use plonky2::hash::hash_types::HashOut;
|
||||
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
|
||||
use rand::rngs::OsRng;
|
||||
use rand::Rng;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn low_degree() {
|
||||
test_low_degree::<GoldilocksField, _, 4>(U32AddManyGate::<GoldilocksField, 4> {
|
||||
num_addends: 4,
|
||||
num_ops: 3,
|
||||
_phantom: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eval_fns() -> Result<()> {
|
||||
const D: usize = 2;
|
||||
type C = PoseidonGoldilocksConfig;
|
||||
type F = <C as GenericConfig<D>>::F;
|
||||
test_eval_fns::<F, C, _, D>(U32AddManyGate::<GoldilocksField, D> {
|
||||
num_addends: 4,
|
||||
num_ops: 3,
|
||||
_phantom: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gate_constraint() {
|
||||
type F = GoldilocksField;
|
||||
type FF = QuarticExtension<GoldilocksField>;
|
||||
const D: usize = 4;
|
||||
const NUM_ADDENDS: usize = 10;
|
||||
const NUM_U32_ADD_MANY_OPS: usize = 3;
|
||||
|
||||
fn get_wires(addends: Vec<Vec<u64>>, carries: Vec<u64>) -> Vec<FF> {
|
||||
let mut v0 = Vec::new();
|
||||
let mut v1 = Vec::new();
|
||||
|
||||
let num_result_limbs = U32AddManyGate::<F, D>::num_result_limbs();
|
||||
let num_carry_limbs = U32AddManyGate::<F, D>::num_carry_limbs();
|
||||
let limb_base = 1 << U32AddManyGate::<F, D>::limb_bits();
|
||||
for op in 0..NUM_U32_ADD_MANY_OPS {
|
||||
let adds = &addends[op];
|
||||
let ca = carries[op];
|
||||
|
||||
let output = adds.iter().sum::<u64>() + ca;
|
||||
let output_result = output & ((1 << 32) - 1);
|
||||
let output_carry = output >> 32;
|
||||
|
||||
let split_to_limbs = |mut val, num| {
|
||||
unfold((), move |_| {
|
||||
let ret = val % limb_base;
|
||||
val /= limb_base;
|
||||
Some(ret)
|
||||
})
|
||||
.take(num)
|
||||
.map(F::from_canonical_u64)
|
||||
};
|
||||
|
||||
let mut result_limbs: Vec<_> =
|
||||
split_to_limbs(output_result, num_result_limbs).collect();
|
||||
let mut carry_limbs: Vec<_> =
|
||||
split_to_limbs(output_carry, num_carry_limbs).collect();
|
||||
|
||||
for a in adds {
|
||||
v0.push(F::from_canonical_u64(*a));
|
||||
}
|
||||
v0.push(F::from_canonical_u64(ca));
|
||||
v0.push(F::from_canonical_u64(output_result));
|
||||
v0.push(F::from_canonical_u64(output_carry));
|
||||
v1.append(&mut result_limbs);
|
||||
v1.append(&mut carry_limbs);
|
||||
}
|
||||
|
||||
v0.iter().chain(v1.iter()).map(|&x| x.into()).collect()
|
||||
}
|
||||
|
||||
let mut rng = OsRng;
|
||||
let addends: Vec<Vec<_>> = (0..NUM_U32_ADD_MANY_OPS)
|
||||
.map(|_| (0..NUM_ADDENDS).map(|_| rng.gen::<u32>() as u64).collect())
|
||||
.collect();
|
||||
let carries: Vec<_> = (0..NUM_U32_ADD_MANY_OPS)
|
||||
.map(|_| rng.gen::<u32>() as u64)
|
||||
.collect();
|
||||
|
||||
let gate = U32AddManyGate::<F, D> {
|
||||
num_addends: NUM_ADDENDS,
|
||||
num_ops: NUM_U32_ADD_MANY_OPS,
|
||||
_phantom: PhantomData,
|
||||
};
|
||||
|
||||
let vars = EvaluationVars {
|
||||
local_constants: &[],
|
||||
local_wires: &get_wires(addends, carries),
|
||||
public_inputs_hash: &HashOut::rand(),
|
||||
};
|
||||
|
||||
assert!(
|
||||
gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()),
|
||||
"Gate constraints are not satisfied."
|
||||
);
|
||||
}
|
||||
}
|
||||
575
src/gates/arithmetic_u32.rs
Normal file
575
src/gates/arithmetic_u32.rs
Normal file
@@ -0,0 +1,575 @@
|
||||
use alloc::boxed::Box;
|
||||
use alloc::string::String;
|
||||
use alloc::vec::Vec;
|
||||
use alloc::{format, vec};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
use itertools::unfold;
|
||||
use plonky2::field::extension::Extendable;
|
||||
use plonky2::field::packed::PackedField;
|
||||
use plonky2::field::types::Field;
|
||||
use plonky2::gates::gate::Gate;
|
||||
use plonky2::gates::packed_util::PackedEvaluableBase;
|
||||
use plonky2::gates::util::StridedConstraintConsumer;
|
||||
use plonky2::hash::hash_types::RichField;
|
||||
use plonky2::iop::ext_target::ExtensionTarget;
|
||||
use plonky2::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use plonky2::iop::target::Target;
|
||||
use plonky2::iop::wire::Wire;
|
||||
use plonky2::iop::witness::{PartitionWitness, Witness, WitnessWrite};
|
||||
use plonky2::plonk::circuit_builder::CircuitBuilder;
|
||||
use plonky2::plonk::circuit_data::CircuitConfig;
|
||||
use plonky2::plonk::vars::{
|
||||
EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch,
|
||||
EvaluationVarsBasePacked,
|
||||
};
|
||||
|
||||
/// A gate to perform a basic mul-add on 32-bit values (we assume they are range-checked beforehand).
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct U32ArithmeticGate<F: RichField + Extendable<D>, const D: usize> {
|
||||
pub num_ops: usize,
|
||||
_phantom: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> U32ArithmeticGate<F, D> {
|
||||
pub fn new_from_config(config: &CircuitConfig) -> Self {
|
||||
Self {
|
||||
num_ops: Self::num_ops(config),
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn num_ops(config: &CircuitConfig) -> usize {
|
||||
let wires_per_op = Self::routed_wires_per_op() + Self::num_limbs();
|
||||
(config.num_wires / wires_per_op).min(config.num_routed_wires / Self::routed_wires_per_op())
|
||||
}
|
||||
|
||||
pub fn wire_ith_multiplicand_0(&self, i: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
Self::routed_wires_per_op() * i
|
||||
}
|
||||
pub fn wire_ith_multiplicand_1(&self, i: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
Self::routed_wires_per_op() * i + 1
|
||||
}
|
||||
pub fn wire_ith_addend(&self, i: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
Self::routed_wires_per_op() * i + 2
|
||||
}
|
||||
|
||||
pub fn wire_ith_output_low_half(&self, i: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
Self::routed_wires_per_op() * i + 3
|
||||
}
|
||||
|
||||
pub fn wire_ith_output_high_half(&self, i: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
Self::routed_wires_per_op() * i + 4
|
||||
}
|
||||
|
||||
pub fn wire_ith_inverse(&self, i: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
Self::routed_wires_per_op() * i + 5
|
||||
}
|
||||
|
||||
pub fn limb_bits() -> usize {
|
||||
2
|
||||
}
|
||||
pub fn num_limbs() -> usize {
|
||||
64 / Self::limb_bits()
|
||||
}
|
||||
pub fn routed_wires_per_op() -> usize {
|
||||
6
|
||||
}
|
||||
pub fn wire_ith_output_jth_limb(&self, i: usize, j: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
debug_assert!(j < Self::num_limbs());
|
||||
Self::routed_wires_per_op() * self.num_ops + Self::num_limbs() * i + j
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for U32ArithmeticGate<F, D> {
|
||||
fn id(&self) -> String {
|
||||
format!("{self:?}")
|
||||
}
|
||||
|
||||
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
for i in 0..self.num_ops {
|
||||
let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)];
|
||||
let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)];
|
||||
let addend = vars.local_wires[self.wire_ith_addend(i)];
|
||||
|
||||
let computed_output = multiplicand_0 * multiplicand_1 + addend;
|
||||
|
||||
let output_low = vars.local_wires[self.wire_ith_output_low_half(i)];
|
||||
let output_high = vars.local_wires[self.wire_ith_output_high_half(i)];
|
||||
let inverse = vars.local_wires[self.wire_ith_inverse(i)];
|
||||
|
||||
// Check canonicity of combined_output = output_high * 2^32 + output_low
|
||||
let combined_output = {
|
||||
let base = F::Extension::from_canonical_u64(1 << 32u64);
|
||||
let one = F::Extension::ONE;
|
||||
let u32_max = F::Extension::from_canonical_u32(u32::MAX);
|
||||
|
||||
// This is zero if and only if the high limb is `u32::MAX`.
|
||||
// u32::MAX - output_high
|
||||
let diff = u32_max - output_high;
|
||||
// If this is zero, the diff is invertible, so the high limb is not `u32::MAX`.
|
||||
// inverse * diff - 1
|
||||
let hi_not_max = inverse * diff - one;
|
||||
// If this is zero, either the high limb is not `u32::MAX`, or the low limb is zero.
|
||||
// hi_not_max * limb_0_u32
|
||||
let hi_not_max_or_lo_zero = hi_not_max * output_low;
|
||||
|
||||
constraints.push(hi_not_max_or_lo_zero);
|
||||
|
||||
output_high * base + output_low
|
||||
};
|
||||
|
||||
constraints.push(combined_output - computed_output);
|
||||
|
||||
let mut combined_low_limbs = F::Extension::ZERO;
|
||||
let mut combined_high_limbs = F::Extension::ZERO;
|
||||
let midpoint = Self::num_limbs() / 2;
|
||||
let base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits());
|
||||
for j in (0..Self::num_limbs()).rev() {
|
||||
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
|
||||
let max_limb = 1 << Self::limb_bits();
|
||||
let product = (0..max_limb)
|
||||
.map(|x| this_limb - F::Extension::from_canonical_usize(x))
|
||||
.product();
|
||||
constraints.push(product);
|
||||
|
||||
if j < midpoint {
|
||||
combined_low_limbs = base * combined_low_limbs + this_limb;
|
||||
} else {
|
||||
combined_high_limbs = base * combined_high_limbs + this_limb;
|
||||
}
|
||||
}
|
||||
constraints.push(combined_low_limbs - output_low);
|
||||
constraints.push(combined_high_limbs - output_high);
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base_one(
|
||||
&self,
|
||||
_vars: EvaluationVarsBase<F>,
|
||||
_yield_constr: StridedConstraintConsumer<F>,
|
||||
) {
|
||||
panic!("use eval_unfiltered_base_packed instead");
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch<F>) -> Vec<F> {
|
||||
self.eval_unfiltered_base_batch_packed(vars_base)
|
||||
}
|
||||
|
||||
fn eval_unfiltered_circuit(
|
||||
&self,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
vars: EvaluationTargets<D>,
|
||||
) -> Vec<ExtensionTarget<D>> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
|
||||
for i in 0..self.num_ops {
|
||||
let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)];
|
||||
let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)];
|
||||
let addend = vars.local_wires[self.wire_ith_addend(i)];
|
||||
|
||||
let computed_output = builder.mul_add_extension(multiplicand_0, multiplicand_1, addend);
|
||||
|
||||
let output_low = vars.local_wires[self.wire_ith_output_low_half(i)];
|
||||
let output_high = vars.local_wires[self.wire_ith_output_high_half(i)];
|
||||
let inverse = vars.local_wires[self.wire_ith_inverse(i)];
|
||||
|
||||
// Check canonicity of combined_output = output_high * 2^32 + output_low
|
||||
let combined_output = {
|
||||
let base: F::Extension = F::from_canonical_u64(1 << 32u64).into();
|
||||
let base_target = builder.constant_extension(base);
|
||||
let one = builder.one_extension();
|
||||
let u32_max =
|
||||
builder.constant_extension(F::Extension::from_canonical_u32(u32::MAX));
|
||||
|
||||
// This is zero if and only if the high limb is `u32::MAX`.
|
||||
let diff = builder.sub_extension(u32_max, output_high);
|
||||
// If this is zero, the diff is invertible, so the high limb is not `u32::MAX`.
|
||||
let hi_not_max = builder.mul_sub_extension(inverse, diff, one);
|
||||
// If this is zero, either the high limb is not `u32::MAX`, or the low limb is zero.
|
||||
let hi_not_max_or_lo_zero = builder.mul_extension(hi_not_max, output_low);
|
||||
|
||||
constraints.push(hi_not_max_or_lo_zero);
|
||||
|
||||
builder.mul_add_extension(output_high, base_target, output_low)
|
||||
};
|
||||
|
||||
constraints.push(builder.sub_extension(combined_output, computed_output));
|
||||
|
||||
let mut combined_low_limbs = builder.zero_extension();
|
||||
let mut combined_high_limbs = builder.zero_extension();
|
||||
let midpoint = Self::num_limbs() / 2;
|
||||
let base = builder
|
||||
.constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits()));
|
||||
for j in (0..Self::num_limbs()).rev() {
|
||||
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
|
||||
let max_limb = 1 << Self::limb_bits();
|
||||
|
||||
let mut product = builder.one_extension();
|
||||
for x in 0..max_limb {
|
||||
let x_target =
|
||||
builder.constant_extension(F::Extension::from_canonical_usize(x));
|
||||
let diff = builder.sub_extension(this_limb, x_target);
|
||||
product = builder.mul_extension(product, diff);
|
||||
}
|
||||
constraints.push(product);
|
||||
|
||||
if j < midpoint {
|
||||
combined_low_limbs =
|
||||
builder.mul_add_extension(base, combined_low_limbs, this_limb);
|
||||
} else {
|
||||
combined_high_limbs =
|
||||
builder.mul_add_extension(base, combined_high_limbs, this_limb);
|
||||
}
|
||||
}
|
||||
|
||||
constraints.push(builder.sub_extension(combined_low_limbs, output_low));
|
||||
constraints.push(builder.sub_extension(combined_high_limbs, output_high));
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn generators(&self, row: usize, _local_constants: &[F]) -> Vec<Box<dyn WitnessGenerator<F>>> {
|
||||
(0..self.num_ops)
|
||||
.map(|i| {
|
||||
let g: Box<dyn WitnessGenerator<F>> = Box::new(
|
||||
U32ArithmeticGenerator {
|
||||
gate: *self,
|
||||
row,
|
||||
i,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
.adapter(),
|
||||
);
|
||||
g
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn num_wires(&self) -> usize {
|
||||
self.num_ops * (Self::routed_wires_per_op() + Self::num_limbs())
|
||||
}
|
||||
|
||||
fn num_constants(&self) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
fn degree(&self) -> usize {
|
||||
1 << Self::limb_bits()
|
||||
}
|
||||
|
||||
fn num_constraints(&self) -> usize {
|
||||
self.num_ops * (4 + Self::num_limbs())
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> PackedEvaluableBase<F, D>
|
||||
for U32ArithmeticGate<F, D>
|
||||
{
|
||||
fn eval_unfiltered_base_packed<P: PackedField<Scalar = F>>(
|
||||
&self,
|
||||
vars: EvaluationVarsBasePacked<P>,
|
||||
mut yield_constr: StridedConstraintConsumer<P>,
|
||||
) {
|
||||
for i in 0..self.num_ops {
|
||||
let multiplicand_0 = vars.local_wires[self.wire_ith_multiplicand_0(i)];
|
||||
let multiplicand_1 = vars.local_wires[self.wire_ith_multiplicand_1(i)];
|
||||
let addend = vars.local_wires[self.wire_ith_addend(i)];
|
||||
|
||||
let computed_output = multiplicand_0 * multiplicand_1 + addend;
|
||||
|
||||
let output_low = vars.local_wires[self.wire_ith_output_low_half(i)];
|
||||
let output_high = vars.local_wires[self.wire_ith_output_high_half(i)];
|
||||
let inverse = vars.local_wires[self.wire_ith_inverse(i)];
|
||||
|
||||
let combined_output = {
|
||||
let base = P::from(F::from_canonical_u64(1 << 32u64));
|
||||
let one = P::ONES;
|
||||
let u32_max = P::from(F::from_canonical_u32(u32::MAX));
|
||||
|
||||
// This is zero if and only if the high limb is `u32::MAX`.
|
||||
// u32::MAX - output_high
|
||||
let diff = u32_max - output_high;
|
||||
// If this is zero, the diff is invertible, so the high limb is not `u32::MAX`.
|
||||
// inverse * diff - 1
|
||||
let hi_not_max = inverse * diff - one;
|
||||
// If this is zero, either the high limb is not `u32::MAX`, or the low limb is zero.
|
||||
// hi_not_max * limb_0_u32
|
||||
let hi_not_max_or_lo_zero = hi_not_max * output_low;
|
||||
|
||||
yield_constr.one(hi_not_max_or_lo_zero);
|
||||
|
||||
output_high * base + output_low
|
||||
};
|
||||
|
||||
yield_constr.one(combined_output - computed_output);
|
||||
|
||||
let mut combined_low_limbs = P::ZEROS;
|
||||
let mut combined_high_limbs = P::ZEROS;
|
||||
let midpoint = Self::num_limbs() / 2;
|
||||
let base = F::from_canonical_u64(1u64 << Self::limb_bits());
|
||||
for j in (0..Self::num_limbs()).rev() {
|
||||
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
|
||||
let max_limb = 1 << Self::limb_bits();
|
||||
let product = (0..max_limb)
|
||||
.map(|x| this_limb - F::from_canonical_usize(x))
|
||||
.product();
|
||||
yield_constr.one(product);
|
||||
|
||||
if j < midpoint {
|
||||
combined_low_limbs = combined_low_limbs * base + this_limb;
|
||||
} else {
|
||||
combined_high_limbs = combined_high_limbs * base + this_limb;
|
||||
}
|
||||
}
|
||||
yield_constr.one(combined_low_limbs - output_low);
|
||||
yield_constr.one(combined_high_limbs - output_high);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct U32ArithmeticGenerator<F: RichField + Extendable<D>, const D: usize> {
|
||||
gate: U32ArithmeticGate<F, D>,
|
||||
row: usize,
|
||||
i: usize,
|
||||
_phantom: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
|
||||
for U32ArithmeticGenerator<F, D>
|
||||
{
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
let local_target = |column| Target::wire(self.row, column);
|
||||
|
||||
vec![
|
||||
local_target(self.gate.wire_ith_multiplicand_0(self.i)),
|
||||
local_target(self.gate.wire_ith_multiplicand_1(self.i)),
|
||||
local_target(self.gate.wire_ith_addend(self.i)),
|
||||
]
|
||||
}
|
||||
|
||||
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
|
||||
let local_wire = |column| Wire {
|
||||
row: self.row,
|
||||
column,
|
||||
};
|
||||
|
||||
let get_local_wire = |column| witness.get_wire(local_wire(column));
|
||||
|
||||
let multiplicand_0 = get_local_wire(self.gate.wire_ith_multiplicand_0(self.i));
|
||||
let multiplicand_1 = get_local_wire(self.gate.wire_ith_multiplicand_1(self.i));
|
||||
let addend = get_local_wire(self.gate.wire_ith_addend(self.i));
|
||||
|
||||
let output = multiplicand_0 * multiplicand_1 + addend;
|
||||
let mut output_u64 = output.to_canonical_u64();
|
||||
|
||||
let output_high_u64 = output_u64 >> 32;
|
||||
let output_low_u64 = output_u64 & ((1 << 32) - 1);
|
||||
|
||||
let output_high = F::from_canonical_u64(output_high_u64);
|
||||
let output_low = F::from_canonical_u64(output_low_u64);
|
||||
|
||||
let output_high_wire = local_wire(self.gate.wire_ith_output_high_half(self.i));
|
||||
let output_low_wire = local_wire(self.gate.wire_ith_output_low_half(self.i));
|
||||
|
||||
out_buffer.set_wire(output_high_wire, output_high);
|
||||
out_buffer.set_wire(output_low_wire, output_low);
|
||||
|
||||
let diff = u32::MAX as u64 - output_high_u64;
|
||||
let inverse = if diff == 0 {
|
||||
F::ZERO
|
||||
} else {
|
||||
F::from_canonical_u64(diff).inverse()
|
||||
};
|
||||
let inverse_wire = local_wire(self.gate.wire_ith_inverse(self.i));
|
||||
out_buffer.set_wire(inverse_wire, inverse);
|
||||
|
||||
let num_limbs = U32ArithmeticGate::<F, D>::num_limbs();
|
||||
let limb_base = 1 << U32ArithmeticGate::<F, D>::limb_bits();
|
||||
let output_limbs_u64 = unfold((), move |_| {
|
||||
let ret = output_u64 % limb_base;
|
||||
output_u64 /= limb_base;
|
||||
Some(ret)
|
||||
})
|
||||
.take(num_limbs);
|
||||
let output_limbs_f = output_limbs_u64.map(F::from_canonical_u64);
|
||||
|
||||
for (j, output_limb) in output_limbs_f.enumerate() {
|
||||
let wire = local_wire(self.gate.wire_ith_output_jth_limb(self.i, j));
|
||||
out_buffer.set_wire(wire, output_limb);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
use plonky2::field::goldilocks_field::GoldilocksField;
|
||||
use plonky2::field::types::Sample;
|
||||
use plonky2::gates::gate_testing::{test_eval_fns, test_low_degree};
|
||||
use plonky2::hash::hash_types::HashOut;
|
||||
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
|
||||
use rand::rngs::OsRng;
|
||||
use rand::Rng;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn low_degree() {
|
||||
test_low_degree::<GoldilocksField, _, 4>(U32ArithmeticGate::<GoldilocksField, 4> {
|
||||
num_ops: 3,
|
||||
_phantom: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eval_fns() -> Result<()> {
|
||||
const D: usize = 2;
|
||||
type C = PoseidonGoldilocksConfig;
|
||||
type F = <C as GenericConfig<D>>::F;
|
||||
test_eval_fns::<F, C, _, D>(U32ArithmeticGate::<GoldilocksField, D> {
|
||||
num_ops: 3,
|
||||
_phantom: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
fn get_wires<
|
||||
F: RichField + Extendable<D>,
|
||||
FF: From<F>,
|
||||
const D: usize,
|
||||
const NUM_U32_ARITHMETIC_OPS: usize,
|
||||
>(
|
||||
multiplicands_0: Vec<u64>,
|
||||
multiplicands_1: Vec<u64>,
|
||||
addends: Vec<u64>,
|
||||
) -> Vec<FF> {
|
||||
let mut v0 = Vec::new();
|
||||
let mut v1 = Vec::new();
|
||||
|
||||
let limb_bits = U32ArithmeticGate::<F, D>::limb_bits();
|
||||
let num_limbs = U32ArithmeticGate::<F, D>::num_limbs();
|
||||
let limb_base = 1 << limb_bits;
|
||||
for c in 0..NUM_U32_ARITHMETIC_OPS {
|
||||
let m0 = multiplicands_0[c];
|
||||
let m1 = multiplicands_1[c];
|
||||
let a = addends[c];
|
||||
|
||||
let mut output = m0 * m1 + a;
|
||||
let output_low = output & ((1 << 32) - 1);
|
||||
let output_high = output >> 32;
|
||||
let diff = u32::MAX as u64 - output_high;
|
||||
let inverse = if diff == 0 {
|
||||
F::ZERO
|
||||
} else {
|
||||
F::from_canonical_u64(diff).inverse()
|
||||
};
|
||||
|
||||
let mut output_limbs = Vec::with_capacity(num_limbs);
|
||||
for _i in 0..num_limbs {
|
||||
output_limbs.push(output % limb_base);
|
||||
output /= limb_base;
|
||||
}
|
||||
let mut output_limbs_f: Vec<_> = output_limbs
|
||||
.into_iter()
|
||||
.map(F::from_canonical_u64)
|
||||
.collect();
|
||||
|
||||
v0.push(F::from_canonical_u64(m0));
|
||||
v0.push(F::from_canonical_u64(m1));
|
||||
v0.push(F::from_noncanonical_u64(a));
|
||||
v0.push(F::from_canonical_u64(output_low));
|
||||
v0.push(F::from_canonical_u64(output_high));
|
||||
v0.push(inverse);
|
||||
v1.append(&mut output_limbs_f);
|
||||
}
|
||||
|
||||
v0.iter().chain(v1.iter()).map(|&x| x.into()).collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gate_constraint() {
|
||||
const D: usize = 2;
|
||||
type C = PoseidonGoldilocksConfig;
|
||||
type F = <C as GenericConfig<D>>::F;
|
||||
type FF = <C as GenericConfig<D>>::FE;
|
||||
const NUM_U32_ARITHMETIC_OPS: usize = 3;
|
||||
|
||||
let mut rng = OsRng;
|
||||
let multiplicands_0: Vec<_> = (0..NUM_U32_ARITHMETIC_OPS)
|
||||
.map(|_| rng.gen::<u32>() as u64)
|
||||
.collect();
|
||||
let multiplicands_1: Vec<_> = (0..NUM_U32_ARITHMETIC_OPS)
|
||||
.map(|_| rng.gen::<u32>() as u64)
|
||||
.collect();
|
||||
let addends: Vec<_> = (0..NUM_U32_ARITHMETIC_OPS)
|
||||
.map(|_| rng.gen::<u32>() as u64)
|
||||
.collect();
|
||||
|
||||
let gate = U32ArithmeticGate::<F, D> {
|
||||
num_ops: NUM_U32_ARITHMETIC_OPS,
|
||||
_phantom: PhantomData,
|
||||
};
|
||||
|
||||
let vars = EvaluationVars {
|
||||
local_constants: &[],
|
||||
local_wires: &get_wires::<F, FF, D, NUM_U32_ARITHMETIC_OPS>(
|
||||
multiplicands_0,
|
||||
multiplicands_1,
|
||||
addends,
|
||||
),
|
||||
public_inputs_hash: &HashOut::rand(),
|
||||
};
|
||||
|
||||
assert!(
|
||||
gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()),
|
||||
"Gate constraints are not satisfied."
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_canonicity() {
|
||||
const D: usize = 2;
|
||||
type C = PoseidonGoldilocksConfig;
|
||||
type F = <C as GenericConfig<D>>::F;
|
||||
type FF = <C as GenericConfig<D>>::FE;
|
||||
const NUM_U32_ARITHMETIC_OPS: usize = 3;
|
||||
|
||||
let multiplicands_0 = vec![0; NUM_U32_ARITHMETIC_OPS];
|
||||
let multiplicands_1 = vec![0; NUM_U32_ARITHMETIC_OPS];
|
||||
// A non-canonical addend will produce a non-canonical output using
|
||||
// get_wires.
|
||||
let addends = vec![0xFFFFFFFF00000001; NUM_U32_ARITHMETIC_OPS];
|
||||
|
||||
let gate = U32ArithmeticGate::<F, D> {
|
||||
num_ops: NUM_U32_ARITHMETIC_OPS,
|
||||
_phantom: PhantomData,
|
||||
};
|
||||
|
||||
let vars = EvaluationVars {
|
||||
local_constants: &[],
|
||||
local_wires: &get_wires::<F, FF, D, NUM_U32_ARITHMETIC_OPS>(
|
||||
multiplicands_0,
|
||||
multiplicands_1,
|
||||
addends,
|
||||
),
|
||||
public_inputs_hash: &HashOut::rand(),
|
||||
};
|
||||
|
||||
assert!(
|
||||
!gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()),
|
||||
"Non-canonical output should not pass constraints."
|
||||
);
|
||||
}
|
||||
}
|
||||
710
src/gates/comparison.rs
Normal file
710
src/gates/comparison.rs
Normal file
@@ -0,0 +1,710 @@
|
||||
use alloc::boxed::Box;
|
||||
use alloc::string::String;
|
||||
use alloc::vec::Vec;
|
||||
use alloc::{format, vec};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
use plonky2::field::extension::Extendable;
|
||||
use plonky2::field::packed::PackedField;
|
||||
use plonky2::field::types::{Field, Field64};
|
||||
use plonky2::gates::gate::Gate;
|
||||
use plonky2::gates::packed_util::PackedEvaluableBase;
|
||||
use plonky2::gates::util::StridedConstraintConsumer;
|
||||
use plonky2::hash::hash_types::RichField;
|
||||
use plonky2::iop::ext_target::ExtensionTarget;
|
||||
use plonky2::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use plonky2::iop::target::Target;
|
||||
use plonky2::iop::wire::Wire;
|
||||
use plonky2::iop::witness::{PartitionWitness, Witness, WitnessWrite};
|
||||
use plonky2::plonk::circuit_builder::CircuitBuilder;
|
||||
use plonky2::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_circuit};
|
||||
use plonky2::plonk::vars::{
|
||||
EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch,
|
||||
EvaluationVarsBasePacked,
|
||||
};
|
||||
use plonky2::util::{bits_u64, ceil_div_usize};
|
||||
|
||||
/// A gate for checking that one value is less than or equal to another.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ComparisonGate<F: Field64 + Extendable<D>, const D: usize> {
|
||||
pub(crate) num_bits: usize,
|
||||
pub(crate) num_chunks: usize,
|
||||
_phantom: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> ComparisonGate<F, D> {
|
||||
pub fn new(num_bits: usize, num_chunks: usize) -> Self {
|
||||
debug_assert!(num_bits < bits_u64(F::ORDER));
|
||||
Self {
|
||||
num_bits,
|
||||
num_chunks,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn chunk_bits(&self) -> usize {
|
||||
ceil_div_usize(self.num_bits, self.num_chunks)
|
||||
}
|
||||
|
||||
pub fn wire_first_input(&self) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
pub fn wire_second_input(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
pub fn wire_result_bool(&self) -> usize {
|
||||
2
|
||||
}
|
||||
|
||||
pub fn wire_most_significant_diff(&self) -> usize {
|
||||
3
|
||||
}
|
||||
|
||||
pub fn wire_first_chunk_val(&self, chunk: usize) -> usize {
|
||||
debug_assert!(chunk < self.num_chunks);
|
||||
4 + chunk
|
||||
}
|
||||
|
||||
pub fn wire_second_chunk_val(&self, chunk: usize) -> usize {
|
||||
debug_assert!(chunk < self.num_chunks);
|
||||
4 + self.num_chunks + chunk
|
||||
}
|
||||
|
||||
pub fn wire_equality_dummy(&self, chunk: usize) -> usize {
|
||||
debug_assert!(chunk < self.num_chunks);
|
||||
4 + 2 * self.num_chunks + chunk
|
||||
}
|
||||
|
||||
pub fn wire_chunks_equal(&self, chunk: usize) -> usize {
|
||||
debug_assert!(chunk < self.num_chunks);
|
||||
4 + 3 * self.num_chunks + chunk
|
||||
}
|
||||
|
||||
pub fn wire_intermediate_value(&self, chunk: usize) -> usize {
|
||||
debug_assert!(chunk < self.num_chunks);
|
||||
4 + 4 * self.num_chunks + chunk
|
||||
}
|
||||
|
||||
/// The `bit_index`th bit of 2^n - 1 + most_significant_diff.
|
||||
pub fn wire_most_significant_diff_bit(&self, bit_index: usize) -> usize {
|
||||
4 + 5 * self.num_chunks + bit_index
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for ComparisonGate<F, D> {
|
||||
fn id(&self) -> String {
|
||||
format!("{self:?}<D={D}>")
|
||||
}
|
||||
|
||||
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
|
||||
let first_input = vars.local_wires[self.wire_first_input()];
|
||||
let second_input = vars.local_wires[self.wire_second_input()];
|
||||
|
||||
// Get chunks and assert that they match
|
||||
let first_chunks: Vec<F::Extension> = (0..self.num_chunks)
|
||||
.map(|i| vars.local_wires[self.wire_first_chunk_val(i)])
|
||||
.collect();
|
||||
let second_chunks: Vec<F::Extension> = (0..self.num_chunks)
|
||||
.map(|i| vars.local_wires[self.wire_second_chunk_val(i)])
|
||||
.collect();
|
||||
|
||||
let first_chunks_combined = reduce_with_powers(
|
||||
&first_chunks,
|
||||
F::Extension::from_canonical_usize(1 << self.chunk_bits()),
|
||||
);
|
||||
let second_chunks_combined = reduce_with_powers(
|
||||
&second_chunks,
|
||||
F::Extension::from_canonical_usize(1 << self.chunk_bits()),
|
||||
);
|
||||
|
||||
constraints.push(first_chunks_combined - first_input);
|
||||
constraints.push(second_chunks_combined - second_input);
|
||||
|
||||
let chunk_size = 1 << self.chunk_bits();
|
||||
|
||||
let mut most_significant_diff_so_far = F::Extension::ZERO;
|
||||
|
||||
for i in 0..self.num_chunks {
|
||||
// Range-check the chunks to be less than `chunk_size`.
|
||||
let first_product: F::Extension = (0..chunk_size)
|
||||
.map(|x| first_chunks[i] - F::Extension::from_canonical_usize(x))
|
||||
.product();
|
||||
let second_product: F::Extension = (0..chunk_size)
|
||||
.map(|x| second_chunks[i] - F::Extension::from_canonical_usize(x))
|
||||
.product();
|
||||
constraints.push(first_product);
|
||||
constraints.push(second_product);
|
||||
|
||||
let difference = second_chunks[i] - first_chunks[i];
|
||||
let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)];
|
||||
let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)];
|
||||
|
||||
// Two constraints to assert that `chunks_equal` is valid.
|
||||
constraints.push(difference * equality_dummy - (F::Extension::ONE - chunks_equal));
|
||||
constraints.push(chunks_equal * difference);
|
||||
|
||||
// Update `most_significant_diff_so_far`.
|
||||
let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)];
|
||||
constraints.push(intermediate_value - chunks_equal * most_significant_diff_so_far);
|
||||
most_significant_diff_so_far =
|
||||
intermediate_value + (F::Extension::ONE - chunks_equal) * difference;
|
||||
}
|
||||
|
||||
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
|
||||
constraints.push(most_significant_diff - most_significant_diff_so_far);
|
||||
|
||||
let most_significant_diff_bits: Vec<F::Extension> = (0..self.chunk_bits() + 1)
|
||||
.map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)])
|
||||
.collect();
|
||||
|
||||
// Range-check the bits.
|
||||
for &bit in &most_significant_diff_bits {
|
||||
constraints.push(bit * (F::Extension::ONE - bit));
|
||||
}
|
||||
|
||||
let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::Extension::TWO);
|
||||
let two_n = F::Extension::from_canonical_u64(1 << self.chunk_bits());
|
||||
constraints.push((two_n + most_significant_diff) - bits_combined);
|
||||
|
||||
// Iff first <= second, the top (n + 1st) bit of (2^n + most_significant_diff) will be 1.
|
||||
let result_bool = vars.local_wires[self.wire_result_bool()];
|
||||
constraints.push(result_bool - most_significant_diff_bits[self.chunk_bits()]);
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base_one(
|
||||
&self,
|
||||
_vars: EvaluationVarsBase<F>,
|
||||
_yield_constr: StridedConstraintConsumer<F>,
|
||||
) {
|
||||
panic!("use eval_unfiltered_base_packed instead");
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch<F>) -> Vec<F> {
|
||||
self.eval_unfiltered_base_batch_packed(vars_base)
|
||||
}
|
||||
|
||||
fn eval_unfiltered_circuit(
|
||||
&self,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
vars: EvaluationTargets<D>,
|
||||
) -> Vec<ExtensionTarget<D>> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
|
||||
let first_input = vars.local_wires[self.wire_first_input()];
|
||||
let second_input = vars.local_wires[self.wire_second_input()];
|
||||
|
||||
// Get chunks and assert that they match
|
||||
let first_chunks: Vec<ExtensionTarget<D>> = (0..self.num_chunks)
|
||||
.map(|i| vars.local_wires[self.wire_first_chunk_val(i)])
|
||||
.collect();
|
||||
let second_chunks: Vec<ExtensionTarget<D>> = (0..self.num_chunks)
|
||||
.map(|i| vars.local_wires[self.wire_second_chunk_val(i)])
|
||||
.collect();
|
||||
|
||||
let chunk_base = builder.constant(F::from_canonical_usize(1 << self.chunk_bits()));
|
||||
let first_chunks_combined =
|
||||
reduce_with_powers_ext_circuit(builder, &first_chunks, chunk_base);
|
||||
let second_chunks_combined =
|
||||
reduce_with_powers_ext_circuit(builder, &second_chunks, chunk_base);
|
||||
|
||||
constraints.push(builder.sub_extension(first_chunks_combined, first_input));
|
||||
constraints.push(builder.sub_extension(second_chunks_combined, second_input));
|
||||
|
||||
let chunk_size = 1 << self.chunk_bits();
|
||||
|
||||
let mut most_significant_diff_so_far = builder.zero_extension();
|
||||
|
||||
let one = builder.one_extension();
|
||||
// Find the chosen chunk.
|
||||
for i in 0..self.num_chunks {
|
||||
// Range-check the chunks to be less than `chunk_size`.
|
||||
let mut first_product = one;
|
||||
let mut second_product = one;
|
||||
for x in 0..chunk_size {
|
||||
let x_f = builder.constant_extension(F::Extension::from_canonical_usize(x));
|
||||
let first_diff = builder.sub_extension(first_chunks[i], x_f);
|
||||
let second_diff = builder.sub_extension(second_chunks[i], x_f);
|
||||
first_product = builder.mul_extension(first_product, first_diff);
|
||||
second_product = builder.mul_extension(second_product, second_diff);
|
||||
}
|
||||
constraints.push(first_product);
|
||||
constraints.push(second_product);
|
||||
|
||||
let difference = builder.sub_extension(second_chunks[i], first_chunks[i]);
|
||||
let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)];
|
||||
let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)];
|
||||
|
||||
// Two constraints to assert that `chunks_equal` is valid.
|
||||
let diff_times_equal = builder.mul_extension(difference, equality_dummy);
|
||||
let not_equal = builder.sub_extension(one, chunks_equal);
|
||||
constraints.push(builder.sub_extension(diff_times_equal, not_equal));
|
||||
constraints.push(builder.mul_extension(chunks_equal, difference));
|
||||
|
||||
// Update `most_significant_diff_so_far`.
|
||||
let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)];
|
||||
let old_diff = builder.mul_extension(chunks_equal, most_significant_diff_so_far);
|
||||
constraints.push(builder.sub_extension(intermediate_value, old_diff));
|
||||
|
||||
let not_equal = builder.sub_extension(one, chunks_equal);
|
||||
let new_diff = builder.mul_extension(not_equal, difference);
|
||||
most_significant_diff_so_far = builder.add_extension(intermediate_value, new_diff);
|
||||
}
|
||||
|
||||
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
|
||||
constraints
|
||||
.push(builder.sub_extension(most_significant_diff, most_significant_diff_so_far));
|
||||
|
||||
let most_significant_diff_bits: Vec<ExtensionTarget<D>> = (0..self.chunk_bits() + 1)
|
||||
.map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)])
|
||||
.collect();
|
||||
|
||||
// Range-check the bits.
|
||||
for &this_bit in &most_significant_diff_bits {
|
||||
let inverse = builder.sub_extension(one, this_bit);
|
||||
constraints.push(builder.mul_extension(this_bit, inverse));
|
||||
}
|
||||
|
||||
let two = builder.two();
|
||||
let bits_combined =
|
||||
reduce_with_powers_ext_circuit(builder, &most_significant_diff_bits, two);
|
||||
let two_n =
|
||||
builder.constant_extension(F::Extension::from_canonical_u64(1 << self.chunk_bits()));
|
||||
let sum = builder.add_extension(two_n, most_significant_diff);
|
||||
constraints.push(builder.sub_extension(sum, bits_combined));
|
||||
|
||||
// Iff first <= second, the top (n + 1st) bit of (2^n + most_significant_diff) will be 1.
|
||||
let result_bool = vars.local_wires[self.wire_result_bool()];
|
||||
constraints.push(
|
||||
builder.sub_extension(result_bool, most_significant_diff_bits[self.chunk_bits()]),
|
||||
);
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn generators(&self, row: usize, _local_constants: &[F]) -> Vec<Box<dyn WitnessGenerator<F>>> {
|
||||
let gen = ComparisonGenerator::<F, D> {
|
||||
row,
|
||||
gate: self.clone(),
|
||||
};
|
||||
vec![Box::new(gen.adapter())]
|
||||
}
|
||||
|
||||
fn num_wires(&self) -> usize {
|
||||
4 + 5 * self.num_chunks + (self.chunk_bits() + 1)
|
||||
}
|
||||
|
||||
fn num_constants(&self) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
fn degree(&self) -> usize {
|
||||
1 << self.chunk_bits()
|
||||
}
|
||||
|
||||
fn num_constraints(&self) -> usize {
|
||||
6 + 5 * self.num_chunks + self.chunk_bits()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> PackedEvaluableBase<F, D>
|
||||
for ComparisonGate<F, D>
|
||||
{
|
||||
fn eval_unfiltered_base_packed<P: PackedField<Scalar = F>>(
|
||||
&self,
|
||||
vars: EvaluationVarsBasePacked<P>,
|
||||
mut yield_constr: StridedConstraintConsumer<P>,
|
||||
) {
|
||||
let first_input = vars.local_wires[self.wire_first_input()];
|
||||
let second_input = vars.local_wires[self.wire_second_input()];
|
||||
|
||||
// Get chunks and assert that they match
|
||||
let first_chunks: Vec<_> = (0..self.num_chunks)
|
||||
.map(|i| vars.local_wires[self.wire_first_chunk_val(i)])
|
||||
.collect();
|
||||
let second_chunks: Vec<_> = (0..self.num_chunks)
|
||||
.map(|i| vars.local_wires[self.wire_second_chunk_val(i)])
|
||||
.collect();
|
||||
|
||||
let first_chunks_combined = reduce_with_powers(
|
||||
&first_chunks,
|
||||
F::from_canonical_usize(1 << self.chunk_bits()),
|
||||
);
|
||||
let second_chunks_combined = reduce_with_powers(
|
||||
&second_chunks,
|
||||
F::from_canonical_usize(1 << self.chunk_bits()),
|
||||
);
|
||||
|
||||
yield_constr.one(first_chunks_combined - first_input);
|
||||
yield_constr.one(second_chunks_combined - second_input);
|
||||
|
||||
let chunk_size = 1 << self.chunk_bits();
|
||||
|
||||
let mut most_significant_diff_so_far = P::ZEROS;
|
||||
|
||||
for i in 0..self.num_chunks {
|
||||
// Range-check the chunks to be less than `chunk_size`.
|
||||
let first_product: P = (0..chunk_size)
|
||||
.map(|x| first_chunks[i] - F::from_canonical_usize(x))
|
||||
.product();
|
||||
let second_product: P = (0..chunk_size)
|
||||
.map(|x| second_chunks[i] - F::from_canonical_usize(x))
|
||||
.product();
|
||||
yield_constr.one(first_product);
|
||||
yield_constr.one(second_product);
|
||||
|
||||
let difference = second_chunks[i] - first_chunks[i];
|
||||
let equality_dummy = vars.local_wires[self.wire_equality_dummy(i)];
|
||||
let chunks_equal = vars.local_wires[self.wire_chunks_equal(i)];
|
||||
|
||||
// Two constraints to assert that `chunks_equal` is valid.
|
||||
yield_constr.one(difference * equality_dummy - (P::ONES - chunks_equal));
|
||||
yield_constr.one(chunks_equal * difference);
|
||||
|
||||
// Update `most_significant_diff_so_far`.
|
||||
let intermediate_value = vars.local_wires[self.wire_intermediate_value(i)];
|
||||
yield_constr.one(intermediate_value - chunks_equal * most_significant_diff_so_far);
|
||||
most_significant_diff_so_far =
|
||||
intermediate_value + (P::ONES - chunks_equal) * difference;
|
||||
}
|
||||
|
||||
let most_significant_diff = vars.local_wires[self.wire_most_significant_diff()];
|
||||
yield_constr.one(most_significant_diff - most_significant_diff_so_far);
|
||||
|
||||
let most_significant_diff_bits: Vec<_> = (0..self.chunk_bits() + 1)
|
||||
.map(|i| vars.local_wires[self.wire_most_significant_diff_bit(i)])
|
||||
.collect();
|
||||
|
||||
// Range-check the bits.
|
||||
for &bit in &most_significant_diff_bits {
|
||||
yield_constr.one(bit * (P::ONES - bit));
|
||||
}
|
||||
|
||||
let bits_combined = reduce_with_powers(&most_significant_diff_bits, F::TWO);
|
||||
let two_n = F::from_canonical_u64(1 << self.chunk_bits());
|
||||
yield_constr.one((most_significant_diff + two_n) - bits_combined);
|
||||
|
||||
// Iff first <= second, the top (n + 1st) bit of (2^n - 1 + most_significant_diff) will be 1.
|
||||
let result_bool = vars.local_wires[self.wire_result_bool()];
|
||||
yield_constr.one(result_bool - most_significant_diff_bits[self.chunk_bits()]);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ComparisonGenerator<F: RichField + Extendable<D>, const D: usize> {
|
||||
row: usize,
|
||||
gate: ComparisonGate<F, D>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
|
||||
for ComparisonGenerator<F, D>
|
||||
{
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
let local_target = |column| Target::wire(self.row, column);
|
||||
|
||||
vec![
|
||||
local_target(self.gate.wire_first_input()),
|
||||
local_target(self.gate.wire_second_input()),
|
||||
]
|
||||
}
|
||||
|
||||
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
|
||||
let local_wire = |column| Wire {
|
||||
row: self.row,
|
||||
column,
|
||||
};
|
||||
|
||||
let get_local_wire = |column| witness.get_wire(local_wire(column));
|
||||
|
||||
let first_input = get_local_wire(self.gate.wire_first_input());
|
||||
let second_input = get_local_wire(self.gate.wire_second_input());
|
||||
|
||||
let first_input_u64 = first_input.to_canonical_u64();
|
||||
let second_input_u64 = second_input.to_canonical_u64();
|
||||
|
||||
let result = F::from_canonical_usize((first_input_u64 <= second_input_u64) as usize);
|
||||
|
||||
let chunk_size = 1 << self.gate.chunk_bits();
|
||||
let first_input_chunks: Vec<F> = (0..self.gate.num_chunks)
|
||||
.scan(first_input_u64, |acc, _| {
|
||||
let tmp = *acc % chunk_size;
|
||||
*acc /= chunk_size;
|
||||
Some(F::from_canonical_u64(tmp))
|
||||
})
|
||||
.collect();
|
||||
let second_input_chunks: Vec<F> = (0..self.gate.num_chunks)
|
||||
.scan(second_input_u64, |acc, _| {
|
||||
let tmp = *acc % chunk_size;
|
||||
*acc /= chunk_size;
|
||||
Some(F::from_canonical_u64(tmp))
|
||||
})
|
||||
.collect();
|
||||
|
||||
let chunks_equal: Vec<F> = (0..self.gate.num_chunks)
|
||||
.map(|i| F::from_bool(first_input_chunks[i] == second_input_chunks[i]))
|
||||
.collect();
|
||||
let equality_dummies: Vec<F> = first_input_chunks
|
||||
.iter()
|
||||
.zip(second_input_chunks.iter())
|
||||
.map(|(&f, &s)| if f == s { F::ONE } else { F::ONE / (s - f) })
|
||||
.collect();
|
||||
|
||||
let mut most_significant_diff_so_far = F::ZERO;
|
||||
let mut intermediate_values = Vec::new();
|
||||
for i in 0..self.gate.num_chunks {
|
||||
if first_input_chunks[i] != second_input_chunks[i] {
|
||||
most_significant_diff_so_far = second_input_chunks[i] - first_input_chunks[i];
|
||||
intermediate_values.push(F::ZERO);
|
||||
} else {
|
||||
intermediate_values.push(most_significant_diff_so_far);
|
||||
}
|
||||
}
|
||||
let most_significant_diff = most_significant_diff_so_far;
|
||||
|
||||
let two_n = F::from_canonical_usize(1 << self.gate.chunk_bits());
|
||||
let two_n_plus_msd = (two_n + most_significant_diff).to_canonical_u64();
|
||||
|
||||
let msd_bits_u64: Vec<u64> = (0..self.gate.chunk_bits() + 1)
|
||||
.scan(two_n_plus_msd, |acc, _| {
|
||||
let tmp = *acc % 2;
|
||||
*acc /= 2;
|
||||
Some(tmp)
|
||||
})
|
||||
.collect();
|
||||
let msd_bits: Vec<F> = msd_bits_u64
|
||||
.iter()
|
||||
.map(|x| F::from_canonical_u64(*x))
|
||||
.collect();
|
||||
|
||||
out_buffer.set_wire(local_wire(self.gate.wire_result_bool()), result);
|
||||
out_buffer.set_wire(
|
||||
local_wire(self.gate.wire_most_significant_diff()),
|
||||
most_significant_diff,
|
||||
);
|
||||
for i in 0..self.gate.num_chunks {
|
||||
out_buffer.set_wire(
|
||||
local_wire(self.gate.wire_first_chunk_val(i)),
|
||||
first_input_chunks[i],
|
||||
);
|
||||
out_buffer.set_wire(
|
||||
local_wire(self.gate.wire_second_chunk_val(i)),
|
||||
second_input_chunks[i],
|
||||
);
|
||||
out_buffer.set_wire(
|
||||
local_wire(self.gate.wire_equality_dummy(i)),
|
||||
equality_dummies[i],
|
||||
);
|
||||
out_buffer.set_wire(local_wire(self.gate.wire_chunks_equal(i)), chunks_equal[i]);
|
||||
out_buffer.set_wire(
|
||||
local_wire(self.gate.wire_intermediate_value(i)),
|
||||
intermediate_values[i],
|
||||
);
|
||||
}
|
||||
for i in 0..self.gate.chunk_bits() + 1 {
|
||||
out_buffer.set_wire(
|
||||
local_wire(self.gate.wire_most_significant_diff_bit(i)),
|
||||
msd_bits[i],
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
use plonky2::field::goldilocks_field::GoldilocksField;
|
||||
use plonky2::field::types::{PrimeField64, Sample};
|
||||
use plonky2::gates::gate_testing::{test_eval_fns, test_low_degree};
|
||||
use plonky2::hash::hash_types::HashOut;
|
||||
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
|
||||
use rand::rngs::OsRng;
|
||||
use rand::Rng;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn wire_indices() {
|
||||
type CG = ComparisonGate<GoldilocksField, 4>;
|
||||
let num_bits = 40;
|
||||
let num_chunks = 5;
|
||||
|
||||
let gate = CG {
|
||||
num_bits,
|
||||
num_chunks,
|
||||
_phantom: PhantomData,
|
||||
};
|
||||
|
||||
assert_eq!(gate.wire_first_input(), 0);
|
||||
assert_eq!(gate.wire_second_input(), 1);
|
||||
assert_eq!(gate.wire_result_bool(), 2);
|
||||
assert_eq!(gate.wire_most_significant_diff(), 3);
|
||||
assert_eq!(gate.wire_first_chunk_val(0), 4);
|
||||
assert_eq!(gate.wire_first_chunk_val(4), 8);
|
||||
assert_eq!(gate.wire_second_chunk_val(0), 9);
|
||||
assert_eq!(gate.wire_second_chunk_val(4), 13);
|
||||
assert_eq!(gate.wire_equality_dummy(0), 14);
|
||||
assert_eq!(gate.wire_equality_dummy(4), 18);
|
||||
assert_eq!(gate.wire_chunks_equal(0), 19);
|
||||
assert_eq!(gate.wire_chunks_equal(4), 23);
|
||||
assert_eq!(gate.wire_intermediate_value(0), 24);
|
||||
assert_eq!(gate.wire_intermediate_value(4), 28);
|
||||
assert_eq!(gate.wire_most_significant_diff_bit(0), 29);
|
||||
assert_eq!(gate.wire_most_significant_diff_bit(8), 37);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn low_degree() {
|
||||
let num_bits = 40;
|
||||
let num_chunks = 5;
|
||||
|
||||
test_low_degree::<GoldilocksField, _, 4>(ComparisonGate::<_, 4>::new(num_bits, num_chunks))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eval_fns() -> Result<()> {
|
||||
let num_bits = 40;
|
||||
let num_chunks = 5;
|
||||
const D: usize = 2;
|
||||
type C = PoseidonGoldilocksConfig;
|
||||
type F = <C as GenericConfig<D>>::F;
|
||||
|
||||
test_eval_fns::<F, C, _, D>(ComparisonGate::<_, 2>::new(num_bits, num_chunks))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gate_constraint() {
|
||||
const D: usize = 2;
|
||||
type C = PoseidonGoldilocksConfig;
|
||||
type F = <C as GenericConfig<D>>::F;
|
||||
type FF = <C as GenericConfig<D>>::FE;
|
||||
|
||||
let num_bits = 40;
|
||||
let num_chunks = 5;
|
||||
let chunk_bits = num_bits / num_chunks;
|
||||
|
||||
// Returns the local wires for a comparison gate given the two inputs.
|
||||
let get_wires = |first_input: F, second_input: F| -> Vec<FF> {
|
||||
let mut v = Vec::new();
|
||||
|
||||
let first_input_u64 = first_input.to_canonical_u64();
|
||||
let second_input_u64 = second_input.to_canonical_u64();
|
||||
|
||||
let result_bool = F::from_bool(first_input_u64 <= second_input_u64);
|
||||
|
||||
let chunk_size = 1 << chunk_bits;
|
||||
let mut first_input_chunks: Vec<F> = (0..num_chunks)
|
||||
.scan(first_input_u64, |acc, _| {
|
||||
let tmp = *acc % chunk_size;
|
||||
*acc /= chunk_size;
|
||||
Some(F::from_canonical_u64(tmp))
|
||||
})
|
||||
.collect();
|
||||
let mut second_input_chunks: Vec<F> = (0..num_chunks)
|
||||
.scan(second_input_u64, |acc, _| {
|
||||
let tmp = *acc % chunk_size;
|
||||
*acc /= chunk_size;
|
||||
Some(F::from_canonical_u64(tmp))
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut chunks_equal: Vec<F> = (0..num_chunks)
|
||||
.map(|i| F::from_bool(first_input_chunks[i] == second_input_chunks[i]))
|
||||
.collect();
|
||||
let mut equality_dummies: Vec<F> = first_input_chunks
|
||||
.iter()
|
||||
.zip(second_input_chunks.iter())
|
||||
.map(|(&f, &s)| if f == s { F::ONE } else { F::ONE / (s - f) })
|
||||
.collect();
|
||||
|
||||
let mut most_significant_diff_so_far = F::ZERO;
|
||||
let mut intermediate_values = Vec::new();
|
||||
for i in 0..num_chunks {
|
||||
if first_input_chunks[i] != second_input_chunks[i] {
|
||||
most_significant_diff_so_far = second_input_chunks[i] - first_input_chunks[i];
|
||||
intermediate_values.push(F::ZERO);
|
||||
} else {
|
||||
intermediate_values.push(most_significant_diff_so_far);
|
||||
}
|
||||
}
|
||||
let most_significant_diff = most_significant_diff_so_far;
|
||||
|
||||
let two_n_plus_msd =
|
||||
(1 << chunk_bits) as u64 + most_significant_diff.to_canonical_u64();
|
||||
let mut msd_bits: Vec<F> = (0..chunk_bits + 1)
|
||||
.scan(two_n_plus_msd, |acc, _| {
|
||||
let tmp = *acc % 2;
|
||||
*acc /= 2;
|
||||
Some(F::from_canonical_u64(tmp))
|
||||
})
|
||||
.collect();
|
||||
|
||||
v.push(first_input);
|
||||
v.push(second_input);
|
||||
v.push(result_bool);
|
||||
v.push(most_significant_diff);
|
||||
v.append(&mut first_input_chunks);
|
||||
v.append(&mut second_input_chunks);
|
||||
v.append(&mut equality_dummies);
|
||||
v.append(&mut chunks_equal);
|
||||
v.append(&mut intermediate_values);
|
||||
v.append(&mut msd_bits);
|
||||
|
||||
v.iter().map(|&x| x.into()).collect()
|
||||
};
|
||||
|
||||
let mut rng = OsRng;
|
||||
let max: u64 = 1 << (num_bits - 1);
|
||||
let first_input_u64 = rng.gen_range(0..max);
|
||||
let second_input_u64 = {
|
||||
let mut val = rng.gen_range(0..max);
|
||||
while val < first_input_u64 {
|
||||
val = rng.gen_range(0..max);
|
||||
}
|
||||
val
|
||||
};
|
||||
|
||||
let first_input = F::from_canonical_u64(first_input_u64);
|
||||
let second_input = F::from_canonical_u64(second_input_u64);
|
||||
|
||||
let less_than_gate = ComparisonGate::<F, D> {
|
||||
num_bits,
|
||||
num_chunks,
|
||||
_phantom: PhantomData,
|
||||
};
|
||||
let less_than_vars = EvaluationVars {
|
||||
local_constants: &[],
|
||||
local_wires: &get_wires(first_input, second_input),
|
||||
public_inputs_hash: &HashOut::rand(),
|
||||
};
|
||||
assert!(
|
||||
less_than_gate
|
||||
.eval_unfiltered(less_than_vars)
|
||||
.iter()
|
||||
.all(|x| x.is_zero()),
|
||||
"Gate constraints are not satisfied."
|
||||
);
|
||||
|
||||
let equal_gate = ComparisonGate::<F, D> {
|
||||
num_bits,
|
||||
num_chunks,
|
||||
_phantom: PhantomData,
|
||||
};
|
||||
let equal_vars = EvaluationVars {
|
||||
local_constants: &[],
|
||||
local_wires: &get_wires(first_input, first_input),
|
||||
public_inputs_hash: &HashOut::rand(),
|
||||
};
|
||||
assert!(
|
||||
equal_gate
|
||||
.eval_unfiltered(equal_vars)
|
||||
.iter()
|
||||
.all(|x| x.is_zero()),
|
||||
"Gate constraints are not satisfied."
|
||||
);
|
||||
}
|
||||
}
|
||||
5
src/gates/mod.rs
Normal file
5
src/gates/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
pub mod add_many_u32;
|
||||
pub mod arithmetic_u32;
|
||||
pub mod comparison;
|
||||
pub mod range_check_u32;
|
||||
pub mod subtraction_u32;
|
||||
307
src/gates/range_check_u32.rs
Normal file
307
src/gates/range_check_u32.rs
Normal file
@@ -0,0 +1,307 @@
|
||||
use alloc::boxed::Box;
|
||||
use alloc::string::String;
|
||||
use alloc::vec::Vec;
|
||||
use alloc::{format, vec};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
use plonky2::field::extension::Extendable;
|
||||
use plonky2::field::types::Field;
|
||||
use plonky2::gates::gate::Gate;
|
||||
use plonky2::gates::util::StridedConstraintConsumer;
|
||||
use plonky2::hash::hash_types::RichField;
|
||||
use plonky2::iop::ext_target::ExtensionTarget;
|
||||
use plonky2::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use plonky2::iop::target::Target;
|
||||
use plonky2::iop::witness::{PartitionWitness, Witness, WitnessWrite};
|
||||
use plonky2::plonk::circuit_builder::CircuitBuilder;
|
||||
use plonky2::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_circuit};
|
||||
use plonky2::plonk::vars::{EvaluationTargets, EvaluationVars, EvaluationVarsBase};
|
||||
use plonky2::util::ceil_div_usize;
|
||||
|
||||
/// A gate which can decompose a number into base B little-endian limbs.
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct U32RangeCheckGate<F: RichField + Extendable<D>, const D: usize> {
|
||||
pub num_input_limbs: usize,
|
||||
_phantom: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> U32RangeCheckGate<F, D> {
|
||||
pub fn new(num_input_limbs: usize) -> Self {
|
||||
Self {
|
||||
num_input_limbs,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub const AUX_LIMB_BITS: usize = 2;
|
||||
pub const BASE: usize = 1 << Self::AUX_LIMB_BITS;
|
||||
|
||||
fn aux_limbs_per_input_limb(&self) -> usize {
|
||||
ceil_div_usize(32, Self::AUX_LIMB_BITS)
|
||||
}
|
||||
pub fn wire_ith_input_limb(&self, i: usize) -> usize {
|
||||
debug_assert!(i < self.num_input_limbs);
|
||||
i
|
||||
}
|
||||
pub fn wire_ith_input_limb_jth_aux_limb(&self, i: usize, j: usize) -> usize {
|
||||
debug_assert!(i < self.num_input_limbs);
|
||||
debug_assert!(j < self.aux_limbs_per_input_limb());
|
||||
self.num_input_limbs + self.aux_limbs_per_input_limb() * i + j
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for U32RangeCheckGate<F, D> {
|
||||
fn id(&self) -> String {
|
||||
format!("{self:?}")
|
||||
}
|
||||
|
||||
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
|
||||
let base = F::Extension::from_canonical_usize(Self::BASE);
|
||||
for i in 0..self.num_input_limbs {
|
||||
let input_limb = vars.local_wires[self.wire_ith_input_limb(i)];
|
||||
let aux_limbs: Vec<_> = (0..self.aux_limbs_per_input_limb())
|
||||
.map(|j| vars.local_wires[self.wire_ith_input_limb_jth_aux_limb(i, j)])
|
||||
.collect();
|
||||
let computed_sum = reduce_with_powers(&aux_limbs, base);
|
||||
|
||||
constraints.push(computed_sum - input_limb);
|
||||
for aux_limb in aux_limbs {
|
||||
constraints.push(
|
||||
(0..Self::BASE)
|
||||
.map(|i| aux_limb - F::Extension::from_canonical_usize(i))
|
||||
.product(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base_one(
|
||||
&self,
|
||||
vars: EvaluationVarsBase<F>,
|
||||
mut yield_constr: StridedConstraintConsumer<F>,
|
||||
) {
|
||||
let base = F::from_canonical_usize(Self::BASE);
|
||||
for i in 0..self.num_input_limbs {
|
||||
let input_limb = vars.local_wires[self.wire_ith_input_limb(i)];
|
||||
let aux_limbs: Vec<_> = (0..self.aux_limbs_per_input_limb())
|
||||
.map(|j| vars.local_wires[self.wire_ith_input_limb_jth_aux_limb(i, j)])
|
||||
.collect();
|
||||
let computed_sum = reduce_with_powers(&aux_limbs, base);
|
||||
|
||||
yield_constr.one(computed_sum - input_limb);
|
||||
for aux_limb in aux_limbs {
|
||||
yield_constr.one(
|
||||
(0..Self::BASE)
|
||||
.map(|i| aux_limb - F::from_canonical_usize(i))
|
||||
.product(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn eval_unfiltered_circuit(
|
||||
&self,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
vars: EvaluationTargets<D>,
|
||||
) -> Vec<ExtensionTarget<D>> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
|
||||
let base = builder.constant(F::from_canonical_usize(Self::BASE));
|
||||
for i in 0..self.num_input_limbs {
|
||||
let input_limb = vars.local_wires[self.wire_ith_input_limb(i)];
|
||||
let aux_limbs: Vec<_> = (0..self.aux_limbs_per_input_limb())
|
||||
.map(|j| vars.local_wires[self.wire_ith_input_limb_jth_aux_limb(i, j)])
|
||||
.collect();
|
||||
let computed_sum = reduce_with_powers_ext_circuit(builder, &aux_limbs, base);
|
||||
|
||||
constraints.push(builder.sub_extension(computed_sum, input_limb));
|
||||
for aux_limb in aux_limbs {
|
||||
constraints.push({
|
||||
let mut acc = builder.one_extension();
|
||||
(0..Self::BASE).for_each(|i| {
|
||||
// We update our accumulator as:
|
||||
// acc' = acc (x - i)
|
||||
// = acc x + (-i) acc
|
||||
// Since -i is constant, we can do this in one arithmetic_extension call.
|
||||
let neg_i = -F::from_canonical_usize(i);
|
||||
acc = builder.arithmetic_extension(F::ONE, neg_i, acc, aux_limb, acc)
|
||||
});
|
||||
acc
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn generators(&self, row: usize, _local_constants: &[F]) -> Vec<Box<dyn WitnessGenerator<F>>> {
|
||||
let gen = U32RangeCheckGenerator { gate: *self, row };
|
||||
vec![Box::new(gen.adapter())]
|
||||
}
|
||||
|
||||
fn num_wires(&self) -> usize {
|
||||
self.num_input_limbs * (1 + self.aux_limbs_per_input_limb())
|
||||
}
|
||||
|
||||
fn num_constants(&self) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
// Bounded by the range-check (x-0)*(x-1)*...*(x-BASE+1).
|
||||
fn degree(&self) -> usize {
|
||||
Self::BASE
|
||||
}
|
||||
|
||||
// 1 for checking the each sum of aux limbs, plus a range check for each aux limb.
|
||||
fn num_constraints(&self) -> usize {
|
||||
self.num_input_limbs * (1 + self.aux_limbs_per_input_limb())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct U32RangeCheckGenerator<F: RichField + Extendable<D>, const D: usize> {
|
||||
gate: U32RangeCheckGate<F, D>,
|
||||
row: usize,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
|
||||
for U32RangeCheckGenerator<F, D>
|
||||
{
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
let num_input_limbs = self.gate.num_input_limbs;
|
||||
(0..num_input_limbs)
|
||||
.map(|i| Target::wire(self.row, self.gate.wire_ith_input_limb(i)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
|
||||
let num_input_limbs = self.gate.num_input_limbs;
|
||||
for i in 0..num_input_limbs {
|
||||
let sum_value = witness
|
||||
.get_target(Target::wire(self.row, self.gate.wire_ith_input_limb(i)))
|
||||
.to_canonical_u64() as u32;
|
||||
|
||||
let base = U32RangeCheckGate::<F, D>::BASE as u32;
|
||||
let limbs = (0..self.gate.aux_limbs_per_input_limb())
|
||||
.map(|j| Target::wire(self.row, self.gate.wire_ith_input_limb_jth_aux_limb(i, j)));
|
||||
let limbs_value = (0..self.gate.aux_limbs_per_input_limb())
|
||||
.scan(sum_value, |acc, _| {
|
||||
let tmp = *acc % base;
|
||||
*acc /= base;
|
||||
Some(F::from_canonical_u32(tmp))
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for (b, b_value) in limbs.zip(limbs_value) {
|
||||
out_buffer.set_target(b, b_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
use itertools::unfold;
|
||||
use plonky2::field::extension::quartic::QuarticExtension;
|
||||
use plonky2::field::goldilocks_field::GoldilocksField;
|
||||
use plonky2::field::types::{Field, Sample};
|
||||
use plonky2::gates::gate_testing::{test_eval_fns, test_low_degree};
|
||||
use plonky2::hash::hash_types::HashOut;
|
||||
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
|
||||
use rand::rngs::OsRng;
|
||||
use rand::Rng;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn low_degree() {
|
||||
test_low_degree::<GoldilocksField, _, 4>(U32RangeCheckGate::new(8))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eval_fns() -> Result<()> {
|
||||
const D: usize = 2;
|
||||
type C = PoseidonGoldilocksConfig;
|
||||
type F = <C as GenericConfig<D>>::F;
|
||||
test_eval_fns::<F, C, _, D>(U32RangeCheckGate::new(8))
|
||||
}
|
||||
|
||||
fn test_gate_constraint(input_limbs: Vec<u64>) {
|
||||
type F = GoldilocksField;
|
||||
type FF = QuarticExtension<GoldilocksField>;
|
||||
const D: usize = 4;
|
||||
const AUX_LIMB_BITS: usize = 2;
|
||||
const BASE: usize = 1 << AUX_LIMB_BITS;
|
||||
const AUX_LIMBS_PER_INPUT_LIMB: usize = ceil_div_usize(32, AUX_LIMB_BITS);
|
||||
|
||||
fn get_wires(input_limbs: Vec<u64>) -> Vec<FF> {
|
||||
let num_input_limbs = input_limbs.len();
|
||||
let mut v = Vec::new();
|
||||
|
||||
for i in 0..num_input_limbs {
|
||||
let input_limb = input_limbs[i];
|
||||
|
||||
let split_to_limbs = |mut val, num| {
|
||||
unfold((), move |_| {
|
||||
let ret = val % (BASE as u64);
|
||||
val /= BASE as u64;
|
||||
Some(ret)
|
||||
})
|
||||
.take(num)
|
||||
.map(F::from_canonical_u64)
|
||||
};
|
||||
|
||||
let mut aux_limbs: Vec<_> =
|
||||
split_to_limbs(input_limb, AUX_LIMBS_PER_INPUT_LIMB).collect();
|
||||
|
||||
v.append(&mut aux_limbs);
|
||||
}
|
||||
|
||||
input_limbs
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(F::from_canonical_u64)
|
||||
.chain(v.iter().cloned())
|
||||
.map(|x| x.into())
|
||||
.collect()
|
||||
}
|
||||
|
||||
let gate = U32RangeCheckGate::<F, D> {
|
||||
num_input_limbs: 8,
|
||||
_phantom: PhantomData,
|
||||
};
|
||||
|
||||
let vars = EvaluationVars {
|
||||
local_constants: &[],
|
||||
local_wires: &get_wires(input_limbs),
|
||||
public_inputs_hash: &HashOut::rand(),
|
||||
};
|
||||
|
||||
assert!(
|
||||
gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()),
|
||||
"Gate constraints are not satisfied."
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gate_constraint_good() {
|
||||
let mut rng = OsRng;
|
||||
let input_limbs: Vec<_> = (0..8).map(|_| rng.gen::<u32>() as u64).collect();
|
||||
|
||||
test_gate_constraint(input_limbs);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_gate_constraint_bad() {
|
||||
let mut rng = OsRng;
|
||||
let input_limbs: Vec<_> = (0..8).map(|_| rng.gen()).collect();
|
||||
|
||||
test_gate_constraint(input_limbs);
|
||||
}
|
||||
}
|
||||
445
src/gates/subtraction_u32.rs
Normal file
445
src/gates/subtraction_u32.rs
Normal file
@@ -0,0 +1,445 @@
|
||||
use alloc::boxed::Box;
|
||||
use alloc::string::String;
|
||||
use alloc::vec::Vec;
|
||||
use alloc::{format, vec};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
use plonky2::field::extension::Extendable;
|
||||
use plonky2::field::packed::PackedField;
|
||||
use plonky2::field::types::Field;
|
||||
use plonky2::gates::gate::Gate;
|
||||
use plonky2::gates::packed_util::PackedEvaluableBase;
|
||||
use plonky2::gates::util::StridedConstraintConsumer;
|
||||
use plonky2::hash::hash_types::RichField;
|
||||
use plonky2::iop::ext_target::ExtensionTarget;
|
||||
use plonky2::iop::generator::{GeneratedValues, SimpleGenerator, WitnessGenerator};
|
||||
use plonky2::iop::target::Target;
|
||||
use plonky2::iop::wire::Wire;
|
||||
use plonky2::iop::witness::{PartitionWitness, Witness, WitnessWrite};
|
||||
use plonky2::plonk::circuit_builder::CircuitBuilder;
|
||||
use plonky2::plonk::circuit_data::CircuitConfig;
|
||||
use plonky2::plonk::vars::{
|
||||
EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch,
|
||||
EvaluationVarsBasePacked,
|
||||
};
|
||||
|
||||
/// A gate to perform a subtraction on 32-bit limbs: given `x`, `y`, and `borrow`, it returns
|
||||
/// the result `x - y - borrow` and, if this underflows, a new `borrow`. Inputs are not range-checked.
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct U32SubtractionGate<F: RichField + Extendable<D>, const D: usize> {
|
||||
pub num_ops: usize,
|
||||
_phantom: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> U32SubtractionGate<F, D> {
|
||||
pub fn new_from_config(config: &CircuitConfig) -> Self {
|
||||
Self {
|
||||
num_ops: Self::num_ops(config),
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn num_ops(config: &CircuitConfig) -> usize {
|
||||
let wires_per_op = 5 + Self::num_limbs();
|
||||
let routed_wires_per_op = 5;
|
||||
(config.num_wires / wires_per_op).min(config.num_routed_wires / routed_wires_per_op)
|
||||
}
|
||||
|
||||
pub fn wire_ith_input_x(&self, i: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
5 * i
|
||||
}
|
||||
pub fn wire_ith_input_y(&self, i: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
5 * i + 1
|
||||
}
|
||||
pub fn wire_ith_input_borrow(&self, i: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
5 * i + 2
|
||||
}
|
||||
|
||||
pub fn wire_ith_output_result(&self, i: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
5 * i + 3
|
||||
}
|
||||
pub fn wire_ith_output_borrow(&self, i: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
5 * i + 4
|
||||
}
|
||||
|
||||
pub fn limb_bits() -> usize {
|
||||
2
|
||||
}
|
||||
// We have limbs for the 32 bits of `output_result`.
|
||||
pub fn num_limbs() -> usize {
|
||||
32 / Self::limb_bits()
|
||||
}
|
||||
|
||||
pub fn wire_ith_output_jth_limb(&self, i: usize, j: usize) -> usize {
|
||||
debug_assert!(i < self.num_ops);
|
||||
debug_assert!(j < Self::num_limbs());
|
||||
5 * self.num_ops + Self::num_limbs() * i + j
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for U32SubtractionGate<F, D> {
|
||||
fn id(&self) -> String {
|
||||
format!("{self:?}")
|
||||
}
|
||||
|
||||
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
for i in 0..self.num_ops {
|
||||
let input_x = vars.local_wires[self.wire_ith_input_x(i)];
|
||||
let input_y = vars.local_wires[self.wire_ith_input_y(i)];
|
||||
let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)];
|
||||
|
||||
let result_initial = input_x - input_y - input_borrow;
|
||||
let base = F::Extension::from_canonical_u64(1 << 32u64);
|
||||
|
||||
let output_result = vars.local_wires[self.wire_ith_output_result(i)];
|
||||
let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)];
|
||||
|
||||
constraints.push(output_result - (result_initial + base * output_borrow));
|
||||
|
||||
// Range-check output_result to be at most 32 bits.
|
||||
let mut combined_limbs = F::Extension::ZERO;
|
||||
let limb_base = F::Extension::from_canonical_u64(1u64 << Self::limb_bits());
|
||||
for j in (0..Self::num_limbs()).rev() {
|
||||
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
|
||||
let max_limb = 1 << Self::limb_bits();
|
||||
let product = (0..max_limb)
|
||||
.map(|x| this_limb - F::Extension::from_canonical_usize(x))
|
||||
.product();
|
||||
constraints.push(product);
|
||||
|
||||
combined_limbs = limb_base * combined_limbs + this_limb;
|
||||
}
|
||||
constraints.push(combined_limbs - output_result);
|
||||
|
||||
// Range-check output_borrow to be one bit.
|
||||
constraints.push(output_borrow * (F::Extension::ONE - output_borrow));
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base_one(
|
||||
&self,
|
||||
_vars: EvaluationVarsBase<F>,
|
||||
_yield_constr: StridedConstraintConsumer<F>,
|
||||
) {
|
||||
panic!("use eval_unfiltered_base_packed instead");
|
||||
}
|
||||
|
||||
fn eval_unfiltered_base_batch(&self, vars_base: EvaluationVarsBaseBatch<F>) -> Vec<F> {
|
||||
self.eval_unfiltered_base_batch_packed(vars_base)
|
||||
}
|
||||
|
||||
fn eval_unfiltered_circuit(
|
||||
&self,
|
||||
builder: &mut CircuitBuilder<F, D>,
|
||||
vars: EvaluationTargets<D>,
|
||||
) -> Vec<ExtensionTarget<D>> {
|
||||
let mut constraints = Vec::with_capacity(self.num_constraints());
|
||||
for i in 0..self.num_ops {
|
||||
let input_x = vars.local_wires[self.wire_ith_input_x(i)];
|
||||
let input_y = vars.local_wires[self.wire_ith_input_y(i)];
|
||||
let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)];
|
||||
|
||||
let diff = builder.sub_extension(input_x, input_y);
|
||||
let result_initial = builder.sub_extension(diff, input_borrow);
|
||||
let base = builder.constant_extension(F::Extension::from_canonical_u64(1 << 32u64));
|
||||
|
||||
let output_result = vars.local_wires[self.wire_ith_output_result(i)];
|
||||
let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)];
|
||||
|
||||
let computed_output = builder.mul_add_extension(base, output_borrow, result_initial);
|
||||
constraints.push(builder.sub_extension(output_result, computed_output));
|
||||
|
||||
// Range-check output_result to be at most 32 bits.
|
||||
let mut combined_limbs = builder.zero_extension();
|
||||
let limb_base = builder
|
||||
.constant_extension(F::Extension::from_canonical_u64(1u64 << Self::limb_bits()));
|
||||
for j in (0..Self::num_limbs()).rev() {
|
||||
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
|
||||
let max_limb = 1 << Self::limb_bits();
|
||||
let mut product = builder.one_extension();
|
||||
for x in 0..max_limb {
|
||||
let x_target =
|
||||
builder.constant_extension(F::Extension::from_canonical_usize(x));
|
||||
let diff = builder.sub_extension(this_limb, x_target);
|
||||
product = builder.mul_extension(product, diff);
|
||||
}
|
||||
constraints.push(product);
|
||||
|
||||
combined_limbs = builder.mul_add_extension(limb_base, combined_limbs, this_limb);
|
||||
}
|
||||
constraints.push(builder.sub_extension(combined_limbs, output_result));
|
||||
|
||||
// Range-check output_borrow to be one bit.
|
||||
let one = builder.one_extension();
|
||||
let not_borrow = builder.sub_extension(one, output_borrow);
|
||||
constraints.push(builder.mul_extension(output_borrow, not_borrow));
|
||||
}
|
||||
|
||||
constraints
|
||||
}
|
||||
|
||||
fn generators(&self, row: usize, _local_constants: &[F]) -> Vec<Box<dyn WitnessGenerator<F>>> {
|
||||
(0..self.num_ops)
|
||||
.map(|i| {
|
||||
let g: Box<dyn WitnessGenerator<F>> = Box::new(
|
||||
U32SubtractionGenerator {
|
||||
gate: *self,
|
||||
row,
|
||||
i,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
.adapter(),
|
||||
);
|
||||
g
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn num_wires(&self) -> usize {
|
||||
self.num_ops * (5 + Self::num_limbs())
|
||||
}
|
||||
|
||||
fn num_constants(&self) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
fn degree(&self) -> usize {
|
||||
1 << Self::limb_bits()
|
||||
}
|
||||
|
||||
fn num_constraints(&self) -> usize {
|
||||
self.num_ops * (3 + Self::num_limbs())
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> PackedEvaluableBase<F, D>
|
||||
for U32SubtractionGate<F, D>
|
||||
{
|
||||
fn eval_unfiltered_base_packed<P: PackedField<Scalar = F>>(
|
||||
&self,
|
||||
vars: EvaluationVarsBasePacked<P>,
|
||||
mut yield_constr: StridedConstraintConsumer<P>,
|
||||
) {
|
||||
for i in 0..self.num_ops {
|
||||
let input_x = vars.local_wires[self.wire_ith_input_x(i)];
|
||||
let input_y = vars.local_wires[self.wire_ith_input_y(i)];
|
||||
let input_borrow = vars.local_wires[self.wire_ith_input_borrow(i)];
|
||||
|
||||
let result_initial = input_x - input_y - input_borrow;
|
||||
let base = F::from_canonical_u64(1 << 32u64);
|
||||
|
||||
let output_result = vars.local_wires[self.wire_ith_output_result(i)];
|
||||
let output_borrow = vars.local_wires[self.wire_ith_output_borrow(i)];
|
||||
|
||||
yield_constr.one(output_result - (result_initial + output_borrow * base));
|
||||
|
||||
// Range-check output_result to be at most 32 bits.
|
||||
let mut combined_limbs = P::ZEROS;
|
||||
let limb_base = F::from_canonical_u64(1u64 << Self::limb_bits());
|
||||
for j in (0..Self::num_limbs()).rev() {
|
||||
let this_limb = vars.local_wires[self.wire_ith_output_jth_limb(i, j)];
|
||||
let max_limb = 1 << Self::limb_bits();
|
||||
let product = (0..max_limb)
|
||||
.map(|x| this_limb - F::from_canonical_usize(x))
|
||||
.product();
|
||||
yield_constr.one(product);
|
||||
|
||||
combined_limbs = combined_limbs * limb_base + this_limb;
|
||||
}
|
||||
yield_constr.one(combined_limbs - output_result);
|
||||
|
||||
// Range-check output_borrow to be one bit.
|
||||
yield_constr.one(output_borrow * (P::ONES - output_borrow));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct U32SubtractionGenerator<F: RichField + Extendable<D>, const D: usize> {
|
||||
gate: U32SubtractionGate<F, D>,
|
||||
row: usize,
|
||||
i: usize,
|
||||
_phantom: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: RichField + Extendable<D>, const D: usize> SimpleGenerator<F>
|
||||
for U32SubtractionGenerator<F, D>
|
||||
{
|
||||
fn dependencies(&self) -> Vec<Target> {
|
||||
let local_target = |column| Target::wire(self.row, column);
|
||||
|
||||
vec![
|
||||
local_target(self.gate.wire_ith_input_x(self.i)),
|
||||
local_target(self.gate.wire_ith_input_y(self.i)),
|
||||
local_target(self.gate.wire_ith_input_borrow(self.i)),
|
||||
]
|
||||
}
|
||||
|
||||
fn run_once(&self, witness: &PartitionWitness<F>, out_buffer: &mut GeneratedValues<F>) {
|
||||
let local_wire = |column| Wire {
|
||||
row: self.row,
|
||||
column,
|
||||
};
|
||||
|
||||
let get_local_wire = |column| witness.get_wire(local_wire(column));
|
||||
|
||||
let input_x = get_local_wire(self.gate.wire_ith_input_x(self.i));
|
||||
let input_y = get_local_wire(self.gate.wire_ith_input_y(self.i));
|
||||
let input_borrow = get_local_wire(self.gate.wire_ith_input_borrow(self.i));
|
||||
|
||||
let result_initial = input_x - input_y - input_borrow;
|
||||
let result_initial_u64 = result_initial.to_canonical_u64();
|
||||
let output_borrow = if result_initial_u64 > 1 << 32u64 {
|
||||
F::ONE
|
||||
} else {
|
||||
F::ZERO
|
||||
};
|
||||
|
||||
let base = F::from_canonical_u64(1 << 32u64);
|
||||
let output_result = result_initial + base * output_borrow;
|
||||
|
||||
let output_result_wire = local_wire(self.gate.wire_ith_output_result(self.i));
|
||||
let output_borrow_wire = local_wire(self.gate.wire_ith_output_borrow(self.i));
|
||||
|
||||
out_buffer.set_wire(output_result_wire, output_result);
|
||||
out_buffer.set_wire(output_borrow_wire, output_borrow);
|
||||
|
||||
let output_result_u64 = output_result.to_canonical_u64();
|
||||
|
||||
let num_limbs = U32SubtractionGate::<F, D>::num_limbs();
|
||||
let limb_base = 1 << U32SubtractionGate::<F, D>::limb_bits();
|
||||
let output_limbs: Vec<_> = (0..num_limbs)
|
||||
.scan(output_result_u64, |acc, _| {
|
||||
let tmp = *acc % limb_base;
|
||||
*acc /= limb_base;
|
||||
Some(F::from_canonical_u64(tmp))
|
||||
})
|
||||
.collect();
|
||||
|
||||
for j in 0..num_limbs {
|
||||
let wire = local_wire(self.gate.wire_ith_output_jth_limb(self.i, j));
|
||||
out_buffer.set_wire(wire, output_limbs[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
use plonky2::field::extension::quartic::QuarticExtension;
|
||||
use plonky2::field::goldilocks_field::GoldilocksField;
|
||||
use plonky2::field::types::{PrimeField64, Sample};
|
||||
use plonky2::gates::gate_testing::{test_eval_fns, test_low_degree};
|
||||
use plonky2::hash::hash_types::HashOut;
|
||||
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
|
||||
use rand::rngs::OsRng;
|
||||
use rand::Rng;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn low_degree() {
|
||||
test_low_degree::<GoldilocksField, _, 4>(U32SubtractionGate::<GoldilocksField, 4> {
|
||||
num_ops: 3,
|
||||
_phantom: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eval_fns() -> Result<()> {
|
||||
const D: usize = 2;
|
||||
type C = PoseidonGoldilocksConfig;
|
||||
type F = <C as GenericConfig<D>>::F;
|
||||
test_eval_fns::<F, C, _, D>(U32SubtractionGate::<GoldilocksField, D> {
|
||||
num_ops: 3,
|
||||
_phantom: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gate_constraint() {
|
||||
type F = GoldilocksField;
|
||||
type FF = QuarticExtension<GoldilocksField>;
|
||||
const D: usize = 4;
|
||||
const NUM_U32_SUBTRACTION_OPS: usize = 3;
|
||||
|
||||
fn get_wires(inputs_x: Vec<u64>, inputs_y: Vec<u64>, borrows: Vec<u64>) -> Vec<FF> {
|
||||
let mut v0 = Vec::new();
|
||||
let mut v1 = Vec::new();
|
||||
|
||||
let limb_bits = U32SubtractionGate::<F, D>::limb_bits();
|
||||
let num_limbs = U32SubtractionGate::<F, D>::num_limbs();
|
||||
let limb_base = 1 << limb_bits;
|
||||
for c in 0..NUM_U32_SUBTRACTION_OPS {
|
||||
let input_x = F::from_canonical_u64(inputs_x[c]);
|
||||
let input_y = F::from_canonical_u64(inputs_y[c]);
|
||||
let input_borrow = F::from_canonical_u64(borrows[c]);
|
||||
|
||||
let result_initial = input_x - input_y - input_borrow;
|
||||
let result_initial_u64 = result_initial.to_canonical_u64();
|
||||
let output_borrow = if result_initial_u64 > 1 << 32u64 {
|
||||
F::ONE
|
||||
} else {
|
||||
F::ZERO
|
||||
};
|
||||
|
||||
let base = F::from_canonical_u64(1 << 32u64);
|
||||
let output_result = result_initial + base * output_borrow;
|
||||
|
||||
let output_result_u64 = output_result.to_canonical_u64();
|
||||
|
||||
let mut output_limbs: Vec<_> = (0..num_limbs)
|
||||
.scan(output_result_u64, |acc, _| {
|
||||
let tmp = *acc % limb_base;
|
||||
*acc /= limb_base;
|
||||
Some(F::from_canonical_u64(tmp))
|
||||
})
|
||||
.collect();
|
||||
|
||||
v0.push(input_x);
|
||||
v0.push(input_y);
|
||||
v0.push(input_borrow);
|
||||
v0.push(output_result);
|
||||
v0.push(output_borrow);
|
||||
v1.append(&mut output_limbs);
|
||||
}
|
||||
|
||||
v0.iter().chain(v1.iter()).map(|&x| x.into()).collect()
|
||||
}
|
||||
|
||||
let mut rng = OsRng;
|
||||
let inputs_x = (0..NUM_U32_SUBTRACTION_OPS)
|
||||
.map(|_| rng.gen::<u32>() as u64)
|
||||
.collect();
|
||||
let inputs_y = (0..NUM_U32_SUBTRACTION_OPS)
|
||||
.map(|_| rng.gen::<u32>() as u64)
|
||||
.collect();
|
||||
let borrows = (0..NUM_U32_SUBTRACTION_OPS)
|
||||
.map(|_| (rng.gen::<u32>() % 2) as u64)
|
||||
.collect();
|
||||
|
||||
let gate = U32SubtractionGate::<F, D> {
|
||||
num_ops: NUM_U32_SUBTRACTION_OPS,
|
||||
_phantom: PhantomData,
|
||||
};
|
||||
|
||||
let vars = EvaluationVars {
|
||||
local_constants: &[],
|
||||
local_wires: &get_wires(inputs_x, inputs_y, borrows),
|
||||
public_inputs_hash: &HashOut::rand(),
|
||||
};
|
||||
|
||||
assert!(
|
||||
gate.eval_unfiltered(vars).iter().all(|x| x.is_zero()),
|
||||
"Gate constraints are not satisfied."
|
||||
);
|
||||
}
|
||||
}
|
||||
8
src/lib.rs
Normal file
8
src/lib.rs
Normal file
@@ -0,0 +1,8 @@
|
||||
#![allow(clippy::needless_range_loop)]
|
||||
#![no_std]
|
||||
|
||||
extern crate alloc;
|
||||
|
||||
pub mod gadgets;
|
||||
pub mod gates;
|
||||
pub mod witness;
|
||||
33
src/witness.rs
Normal file
33
src/witness.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
use plonky2::field::types::{Field, PrimeField64};
|
||||
use plonky2::iop::generator::GeneratedValues;
|
||||
use plonky2::iop::witness::{Witness, WitnessWrite};
|
||||
|
||||
use crate::gadgets::arithmetic_u32::U32Target;
|
||||
|
||||
pub trait WitnessU32<F: PrimeField64>: Witness<F> {
|
||||
fn set_u32_target(&mut self, target: U32Target, value: u32);
|
||||
fn get_u32_target(&self, target: U32Target) -> (u32, u32);
|
||||
}
|
||||
|
||||
impl<T: Witness<F>, F: PrimeField64> WitnessU32<F> for T {
|
||||
fn set_u32_target(&mut self, target: U32Target, value: u32) {
|
||||
self.set_target(target.0, F::from_canonical_u32(value));
|
||||
}
|
||||
|
||||
fn get_u32_target(&self, target: U32Target) -> (u32, u32) {
|
||||
let x_u64 = self.get_target(target.0).to_canonical_u64();
|
||||
let low = x_u64 as u32;
|
||||
let high = (x_u64 >> 32) as u32;
|
||||
(low, high)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait GeneratedValuesU32<F: Field> {
|
||||
fn set_u32_target(&mut self, target: U32Target, value: u32);
|
||||
}
|
||||
|
||||
impl<F: Field> GeneratedValuesU32<F> for GeneratedValues<F> {
|
||||
fn set_u32_target(&mut self, target: U32Target, value: u32) {
|
||||
self.set_target(target.0, F::from_canonical_u32(value))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user