Clean up code, still works

This commit is contained in:
Brian Lawrence
2024-09-30 15:18:48 -07:00
parent be2eec8b9c
commit e8aaa32322
2 changed files with 110 additions and 115 deletions

View File

@@ -55,7 +55,6 @@ impl SchnorrSigner{
pub fn keygen(&self, sk: &SchnorrSecretKey) -> SchnorrPublicKey { pub fn keygen(&self, sk: &SchnorrSecretKey) -> SchnorrPublicKey {
let pk: GoldilocksField = Self::pow(self.PRIME_GROUP_GEN, sk.sk).inverse(); let pk: GoldilocksField = Self::pow(self.PRIME_GROUP_GEN, sk.sk).inverse();
println!("{:?}", self.PRIME_GROUP_GEN);
// self.PRIME_GROUP_GEN is 6612579038192137166 // self.PRIME_GROUP_GEN is 6612579038192137166
SchnorrPublicKey{pk: pk} SchnorrPublicKey{pk: pk}
} }
@@ -66,7 +65,6 @@ impl SchnorrSigner{
.copied() .copied()
.collect(); .collect();
println!("Running hash on concatenated elts: {:?}", poseidon_input);
let h = PoseidonHash::hash_no_pad(&poseidon_input); let h = PoseidonHash::hash_no_pad(&poseidon_input);
h.elements[0].to_canonical_u64() % self.PRIME_GROUP_ORDER h.elements[0].to_canonical_u64() % self.PRIME_GROUP_ORDER
} }
@@ -89,9 +87,6 @@ impl SchnorrSigner{
assert!(k < self.PRIME_GROUP_ORDER); assert!(k < self.PRIME_GROUP_ORDER);
assert!(sk.sk < self.PRIME_GROUP_ORDER); assert!(sk.sk < self.PRIME_GROUP_ORDER);
assert!(e < self.PRIME_GROUP_ORDER); assert!(e < self.PRIME_GROUP_ORDER);
//println!("Super secret k: {:?}", k);
//println!("Super secret r: {:?}", r);
//println!("PRIME_GROUP_ORDER: {:?}", self.PRIME_GROUP_ORDER);
let mut s128: u128 = ((k as u128) + (sk.sk as u128) * (e as u128)); let mut s128: u128 = ((k as u128) + (sk.sk as u128) * (e as u128));
s128 %= self.PRIME_GROUP_ORDER as u128; s128 %= self.PRIME_GROUP_ORDER as u128;
let s: u64 = s128 as u64; let s: u64 = s128 as u64;

View File

@@ -32,6 +32,111 @@ use crate::schnorr::{SchnorrPublicKey, SchnorrSignature};
type GoldF = GoldilocksField; type GoldF = GoldilocksField;
#[derive(Debug, Default)]
pub struct Mod65537Generator {
a: Target,
q: Target,
r: Target,
}
impl SimpleGenerator<GoldF, 2> for Mod65537Generator {
fn id(&self) -> String {
"Mod65537Generator".to_string()
}
fn dependencies(&self) -> Vec<Target> {
vec![self.a]
}
fn run_once(
&self,
witness: &PartitionWitness<GoldF>,
out_buffer: &mut GeneratedValues<GoldF>,
) -> Result<()> {
let a = witness.get_target(self.a);
let a64 = a.to_canonical_u64();
let q64 = a64 / 65537;
let r64 = a64 % 65537;
out_buffer.set_target(self.q, GoldF::from_canonical_u64(q64));
out_buffer.set_target(self.r, GoldF::from_canonical_u64(r64));
Ok(())
}
fn serialize(&self, dst: &mut Vec<u8>, common_data: &CommonCircuitData<GoldF, 2>) -> IoResult<()> {
dst.write_target(self.a)?;
dst.write_target(self.q)?;
dst.write_target(self.r)?;
Ok(())
}
fn deserialize(src: &mut Buffer, common_data: &CommonCircuitData<GoldF, 2>) -> IoResult<Self>
where
Self: Sized
{
let a = src.read_target()?;
let q = src.read_target()?;
let r = src.read_target()?;
Ok(Self { a, q, r })
}
}
pub struct Mod65537Builder {}
impl Mod65537Builder {
// Reduce a modulo the constant 65537
// where a is the canonical representative for an element of the field
// (meaning: 0 \leq a < p)
// To prove this, write
// a = 65537 * q + r, and do range checks to check that:
// 0 <= q <= floor(p / 65537)
// 0 <= r < 65537
// (these first two checks guarantee that a lies in the range [0, p + 65536])
// if q = floor(p / 65537) then r = 0
// (note that p % 65537 == 1 so this is the only possibility)
pub(crate) fn mod_65537 (
builder: &mut CircuitBuilder::<GoldF, 2>,
a: Target,
) -> Target {
let q = builder.add_virtual_target();
let r = builder.add_virtual_target();
// the Mod65537Generator will assign values to q and r later
builder.add_simple_generator( Mod65537Generator { a, q, r } );
// impose four constraints
// 1. a = 65537 * q + r
let t65537 = builder.constant(GoldF::from_canonical_u64(65537));
let a_copy = builder.mul_add(t65537, q, r);
builder.connect(a, a_copy);
// 2. 0 <= q <= floor(p / 65537)
// max_q is 281470681743360 = floor(p / 65537) = (p-1) / 65537 = 2^48 - 2^32
let max_q = builder.constant(GoldF::from_canonical_u64(281470681743360));
builder.range_check(q, 48);
let diff_q = builder.sub(max_q, q);
builder.range_check(diff_q, 48);
// 3. 0 <= r < 65537
let max_r = builder.constant(GoldF::from_canonical_u64(65537));
builder.range_check(r, 17);
let diff_r = builder.sub(max_r, r);
builder.range_check(diff_r, 17);
// 4. if q = floor(p / 65537) then r = 0
let q_equals_max = builder.is_equal(q, max_q);
let prod_temp = builder.mul(q_equals_max.target, r);
let zero_temp = builder.zero();
builder.connect(prod_temp, zero_temp);
// throw in the Generator to tell builder how to compute r
builder.add_simple_generator( Mod65537Generator {a, q, r} );
r
}
}
pub struct MessageTarget { pub struct MessageTarget {
msg: Vec<Target>, msg: Vec<Target>,
} }
@@ -87,113 +192,10 @@ impl SchnorrPublicKeyTarget {
} }
} }
#[derive(Debug, Default)]
pub struct Mod65537Generator {
a: Target,
q: Target,
r: Target,
}
impl SimpleGenerator<GoldF, 2> for Mod65537Generator {
fn id(&self) -> String {
"Mod65537Generator".to_string()
}
fn dependencies(&self) -> Vec<Target> {
vec![self.a]
}
fn run_once(
&self,
witness: &PartitionWitness<GoldF>,
out_buffer: &mut GeneratedValues<GoldF>,
) -> Result<()> {
let a = witness.get_target(self.a);
let a64 = a.to_canonical_u64();
let q64 = a64 / 65537;
let r64 = a64 % 65537;
out_buffer.set_target(self.q, GoldF::from_canonical_u64(q64));
out_buffer.set_target(self.r, GoldF::from_canonical_u64(r64));
Ok(())
}
fn serialize(&self, dst: &mut Vec<u8>, common_data: &CommonCircuitData<GoldF, 2>) -> IoResult<()> {
println!("SERIALIZATION! What is this good for?");
dst.write_target(self.a)?;
dst.write_target(self.q)?;
dst.write_target(self.r)?;
Ok(())
}
fn deserialize(src: &mut Buffer, common_data: &CommonCircuitData<GoldF, 2>) -> IoResult<Self>
where
Self: Sized
{
println!("DESERIALIZATION! What is this good for?");
let a = src.read_target()?;
let q = src.read_target()?;
let r = src.read_target()?;
Ok(Self { a, q, r })
}
}
pub struct SchnorrBuilder {} pub struct SchnorrBuilder {}
impl SchnorrBuilder { impl SchnorrBuilder {
// Reduce a modulo the constant 65537
// where a is the canonical representative for an element of the field
// (meaning: 0 \leq a < p)
// To verify this, write
// a = 65537 * q + r, and do range checks to check that:
// 0 <= q <= floor(p / 65537)
// 0 <= r < 65537
// (these first two checks guarantee that a lies in the range [0, p + 65536])
// if q = floor(p / 65537) then r = 0
// (note that p % 65537 == 1 so this is the only possibility)
pub(crate) fn mod_65537 <
//C: GenericConfig<2, F = GoldF>,
> (
builder: &mut CircuitBuilder::<GoldF, 2>,
a: Target,
) -> Target {
let q = builder.add_virtual_target();
let r = builder.add_virtual_target();
// the Mod65537Generator will assign values to q and r later
builder.add_simple_generator( Mod65537Generator { a, q, r } );
// impose four constraints
// 1. a = 65537 * q + r
let t65537 = builder.constant(GoldF::from_canonical_u64(65537));
let a_copy = builder.mul_add(t65537, q, r);
builder.connect(a, a_copy);
// 2. 0 <= q <= floor(p / 65537)
// max_q is 281470681743360 = floor(p / 65537) = (p-1) / 65537 = 2^48 - 2^32
let max_q = builder.constant(GoldF::from_canonical_u64(281470681743360));
builder.range_check(q, 48);
let diff_q = builder.sub(max_q, q);
builder.range_check(diff_q, 48);
// 3. 0 <= r < 65537
let max_r = builder.constant(GoldF::from_canonical_u64(65537));
builder.range_check(r, 17);
let diff_r = builder.sub(max_r, r);
builder.range_check(diff_r, 17);
// 4. if q = floor(p / 65537) then r = 0
let q_equals_max = builder.is_equal(q, max_q);
let prod_temp = builder.mul(q_equals_max.target, r);
let zero_temp = builder.zero();
builder.connect(prod_temp, zero_temp);
// throw in the Generator to tell builder how to compute r
builder.add_simple_generator( Mod65537Generator {a, q, r} );
r
}
pub fn constrain_sig < pub fn constrain_sig <
C: GenericConfig<2, F = GoldF>, C: GenericConfig<2, F = GoldF>,
@@ -204,13 +206,12 @@ impl SchnorrBuilder {
msg: &MessageTarget, msg: &MessageTarget,
pk: &SchnorrPublicKeyTarget, pk: &SchnorrPublicKeyTarget,
) -> () { ) -> () {
println!("WARNING constrain_sig() is not done yet DONT USE IT");
let PRIME_GROUP_GEN: Target = builder.constant(GoldF::from_canonical_u64(6612579038192137166)); let PRIME_GROUP_GEN: Target = builder.constant(GoldF::from_canonical_u64(6612579038192137166));
let PRIME_GROUP_ORDER: Target = builder.constant(GoldF::from_canonical_u64(65537)); let PRIME_GROUP_ORDER: Target = builder.constant(GoldF::from_canonical_u64(65537));
const num_bits_exp: usize = 32; const num_bits_exp: usize = 32;
/* /* here's the direct verification calculation,
which we verify in-circuit
let r: GoldF = Self::pow(self.PRIME_GROUP_GEN, sig.s) let r: GoldF = Self::pow(self.PRIME_GROUP_GEN, sig.s)
* Self::pow(pk.pk, sig.e); * Self::pow(pk.pk, sig.e);
let e_v: u64 = self.hash_insecure(&r, msg); let e_v: u64 = self.hash_insecure(&r, msg);
@@ -229,7 +230,7 @@ impl SchnorrBuilder {
hash_input, hash_input,
).elements[0]; // whoops have to take mod group order; ).elements[0]; // whoops have to take mod group order;
let e: Target = Self::mod_65537(builder, hash_output); let e: Target = Mod65537Builder::mod_65537(builder, hash_output);
// enforce equality // enforce equality
builder.connect(e, sig.e); builder.connect(e, sig.e);
@@ -239,7 +240,7 @@ impl SchnorrBuilder {
#[cfg(test)] #[cfg(test)]
mod tests{ mod tests{
use crate::schnorr::{SchnorrPublicKey, SchnorrSecretKey, SchnorrSigner, SchnorrSignature}; use crate::schnorr::{SchnorrPublicKey, SchnorrSecretKey, SchnorrSigner, SchnorrSignature};
use crate::schnorr_prover::{MessageTarget, SchnorrBuilder, SchnorrPublicKeyTarget, SchnorrSignatureTarget}; use crate::schnorr_prover::{MessageTarget, Mod65537Builder, SchnorrBuilder, SchnorrPublicKeyTarget, SchnorrSignatureTarget};
use plonky2::hash::poseidon::Poseidon; use plonky2::hash::poseidon::Poseidon;
use plonky2::iop::{ use plonky2::iop::{
target::Target, target::Target,
@@ -280,7 +281,7 @@ mod tests{
.collect(); .collect();
let r: Vec<Target> = a.iter() let r: Vec<Target> = a.iter()
.map(|targ| SchnorrBuilder::mod_65537(&mut builder, *targ)) .map(|targ| Mod65537Builder::mod_65537(&mut builder, *targ))
.collect(); .collect();
// check that the outputs are correct, // check that the outputs are correct,
@@ -290,7 +291,6 @@ mod tests{
let r_expected: Vec<Target> = r_expected64.iter() let r_expected: Vec<Target> = r_expected64.iter()
.map(|x| builder.constant(GoldilocksField::from_canonical_u64(*x))) .map(|x| builder.constant(GoldilocksField::from_canonical_u64(*x)))
.collect(); .collect();
r.iter().zip(r_expected.iter()) r.iter().zip(r_expected.iter())
.for_each(|(x, y)| builder.connect(*x, *y)); .for_each(|(x, y)| builder.connect(*x, *y));