Browse Source

add some error handling

aggregated-schnorr-musig
arnaucube 4 years ago
parent
commit
addcca64e5
4 changed files with 90 additions and 75 deletions
  1. +1
    -1
      Cargo.toml
  2. +4
    -3
      README.md
  3. +60
    -51
      src/lib.rs
  4. +25
    -20
      src/utils.rs

+ 1
- 1
Cargo.toml

@ -1,6 +1,6 @@
[package]
name = "babyjubjub-rs"
version = "0.0.2"
version = "0.0.3"
authors = ["arnaucube <root@arnaucube.com>"]
edition = "2018"
license = "GPL-3.0"

+ 4
- 3
README.md

@ -1,5 +1,5 @@
# babyjubjub-rs [![Crates.io](https://img.shields.io/crates/v/babyjubjub-rs.svg)](https://crates.io/crates/babyjubjub-rs) [![Build Status](https://travis-ci.org/arnaucube/babyjubjub-rs.svg?branch=master)](https://travis-ci.org/arnaucube/babyjubjub-rs)
BabyJubJub elliptic curve implementation in Rust. Is a twisted edwards curve embedded in the curve of BN128.
BabyJubJub elliptic curve implementation in Rust. A twisted edwards curve embedded in the curve of BN128.
BabyJubJub curve explanation: https://medium.com/zokrates/efficient-ecc-in-zksnarks-using-zokrates-bd9ae37b8186
@ -10,7 +10,7 @@ Uses:
Compatible with the BabyJubJub Go implementation from https://github.com/iden3/go-iden3-crypto
## Warning
Doing this in my free time to get familiar with Rust, do not use in production.
Doing this in my free time to get familiar with Rust, **do not use in production**.
- [x] point addition
- [x] point scalar multiplication
@ -24,7 +24,8 @@ Doing this in my free time to get familiar with Rust, do not use in production.
### References
- BabyJubJub curve explanation: https://medium.com/zokrates/efficient-ecc-in-zksnarks-using-zokrates-bd9ae37b8186
- C++ https://github.com/barryWhiteHat/baby_jubjub_ecc
- C++ & Explanation https://github.com/barryWhiteHat/baby_jubjub
- C++ https://github.com/barryWhiteHat/baby_jubjub_ecc
- Javascript & Circom: https://github.com/iden3/circomlib
- Go https://github.com/iden3/go-iden3-crypto
- JubJub curve explanation: https://z.cash/technology/jubjub/

+ 60
- 51
src/lib.rs

@ -62,7 +62,7 @@ pub struct Point {
}
impl Point {
pub fn add(&self, q: &Point) -> Point {
pub fn add(&self, q: &Point) -> Result<Point, String> {
// x = (x1*y2+y1*x2)/(c*(1+d*x1*x2*y1*y2))
// y = (y1*y2-x1*x2)/(c*(1-d*x1*x2*y1*y2))
@ -70,19 +70,19 @@ impl Point {
let one: BigInt = One::one();
let x_num: BigInt = &self.x * &q.y + &self.y * &q.x;
let x_den: BigInt = &one + &D.clone() * &self.x * &q.x * &self.y * &q.y;
let x_den_inv = utils::modinv(&x_den, &Q);
let x_den_inv = utils::modinv(&x_den, &Q)?;
let x: BigInt = utils::modulus(&(&x_num * &x_den_inv), &Q);
// y = (y1 * y2 - a * x1 * x2) / (1 - d * x1 * x2 * y1 * y2)
let y_num = &self.y * &q.y - &A.clone() * &self.x * &q.x;
let y_den = utils::modulus(&(&one - &D.clone() * &self.x * &q.x * &self.y * &q.y), &Q);
let y_den_inv = utils::modinv(&y_den, &Q);
let y_den_inv = utils::modinv(&y_den, &Q)?;
let y: BigInt = utils::modulus(&(&y_num * &y_den_inv), &Q);
Point { x: x, y: y }
Ok(Point { x: x, y: y })
}
pub fn mul_scalar(&self, n: BigInt) -> Point {
pub fn mul_scalar(&self, n: BigInt) -> Result<Point, String> {
// TODO use & in n to avoid clones on function call
let mut r: Point = Point {
x: Zero::zero(),
@ -96,14 +96,14 @@ impl Point {
while rem != zero {
let is_odd = &rem & &one == one;
if is_odd == true {
r = r.add(&exp);
r = r.add(&exp)?;
}
exp = exp.add(&exp);
exp = exp.add(&exp)?;
rem = rem >> 1;
}
r.x = utils::modulus(&r.x, &Q);
r.y = utils::modulus(&r.y, &Q);
r
Ok(r)
}
pub fn compress(&self) -> [u8; 32] {
@ -111,7 +111,7 @@ impl Point {
let (_, y_bytes) = self.y.to_bytes_le();
let len = min(y_bytes.len(), r.len());
r[..len].copy_from_slice(&y_bytes[..len]);
if &self.x >= &(&Q.clone() >> 1) {
if &self.x > &(&Q.clone() >> 1) {
r[31] = r[31] | 0x80;
}
r
@ -133,18 +133,15 @@ pub fn decompress_point(bb: [u8; 32]) -> Result {
let one: BigInt = One::one();
// x^2 = (1 - y^2) / (a - d * y^2) (mod p)
let mut x: BigInt = utils::modulus(
&((one - utils::modulus(&(&y * &y), &Q))
* utils::modinv(
&utils::modulus(
&(&A.clone() - utils::modulus(&(&D.clone() * (&y * &y)), &Q)),
&Q,
),
&Q,
)),
let den = utils::modinv(
&utils::modulus(
&(&A.clone() - utils::modulus(&(&D.clone() * (&y * &y)), &Q)),
&Q,
),
&Q,
);
x = utils::modsqrt(&x, &Q);
)?;
let mut x: BigInt = utils::modulus(&((one - utils::modulus(&(&y * &y), &Q)) * den), &Q);
x = utils::modsqrt(&x, &Q)?;
if sign && !(&x > &(&Q.clone() >> 1)) || (!sign && (&x > &(&Q.clone() >> 1))) {
x = x * -1.to_bigint().unwrap();
@ -191,10 +188,10 @@ pub struct PrivateKey {
}
impl PrivateKey {
pub fn public(&self) -> Point {
pub fn public(&self) -> Result<Point, String> {
// https://tools.ietf.org/html/rfc8032#section-5.1.5
let pk = B8.mul_scalar(self.key.clone());
pk.clone()
let pk = B8.mul_scalar(self.key.clone())?;
Ok(pk.clone())
}
pub fn sign_mimc(&self, msg: BigInt) -> Result<Signature, String> {
@ -209,15 +206,12 @@ impl PrivateKey {
let r_bytes = utils::concatenate_arrays(s, &msg_bytes);
let mut r = BigInt::from_bytes_be(Sign::Plus, &r_bytes[..]);
r = utils::modulus(&r, &SUBORDER);
let r8: Point = B8.mul_scalar(r.clone());
let a = &self.public();
let r8: Point = B8.mul_scalar(r.clone())?;
let a = &self.public()?;
let hm_input = vec![r8.x.clone(), r8.y.clone(), a.x.clone(), a.y.clone(), msg];
let mimc7 = Mimc7::new();
let hm = match mimc7.hash(hm_input) {
Result::Err(err) => return Err(err.to_string()),
Result::Ok(hm) => hm,
};
let hm = mimc7.hash(hm_input)?;
let mut s = &self.key << 3;
s = hm * s;
@ -241,15 +235,12 @@ impl PrivateKey {
let r_bytes = utils::concatenate_arrays(s, &msg_bytes);
let mut r = BigInt::from_bytes_be(Sign::Plus, &r_bytes[..]);
r = utils::modulus(&r, &SUBORDER);
let r8: Point = B8.mul_scalar(r.clone());
let a = &self.public();
let r8: Point = B8.mul_scalar(r.clone())?;
let a = &self.public()?;
let hm_input = vec![r8.x.clone(), r8.y.clone(), a.x.clone(), a.y.clone(), msg];
let poseidon = Poseidon::new();
let hm = match poseidon.hash(hm_input) {
Result::Err(err) => return Err(err.to_string()),
Result::Ok(hm) => hm,
};
let hm = poseidon.hash(hm_input)?;
let mut s = &self.key << 3;
s = hm * s;
@ -295,8 +286,17 @@ pub fn verify_mimc(pk: Point, sig: Signature, msg: BigInt) -> bool {
Result::Err(_) => return false,
Result::Ok(hm) => hm,
};
let l = B8.mul_scalar(sig.s);
let r = sig.r_b8.add(&pk.mul_scalar(8.to_bigint().unwrap() * hm));
let l = match B8.mul_scalar(sig.s) {
Result::Err(_) => return false,
Result::Ok(l) => l,
};
let r = match sig
.r_b8
.add(&pk.mul_scalar(8.to_bigint().unwrap() * hm).unwrap())
{
Result::Err(_) => return false,
Result::Ok(r) => r,
};
if l.x == r.x && l.y == r.y {
return true;
}
@ -315,8 +315,17 @@ pub fn verify_poseidon(pk: Point, sig: Signature, msg: BigInt) -> bool {
Result::Err(_) => return false,
Result::Ok(hm) => hm,
};
let l = B8.mul_scalar(sig.s);
let r = sig.r_b8.add(&pk.mul_scalar(8.to_bigint().unwrap() * hm));
let l = match B8.mul_scalar(sig.s) {
Result::Err(_) => return false,
Result::Ok(l) => l,
};
let r = match sig
.r_b8
.add(&pk.mul_scalar(8.to_bigint().unwrap() * hm).unwrap())
{
Result::Err(_) => return false,
Result::Ok(r) => r,
};
if l.x == r.x && l.y == r.y {
return true;
}
@ -355,7 +364,7 @@ mod tests {
)
.unwrap(),
};
let res = p.add(&q);
let res = p.add(&q).unwrap();
assert_eq!(
res.x.to_string(),
"6890855772600357754907169075114257697580319025794532037257385534741338397365"
@ -391,7 +400,7 @@ mod tests {
)
.unwrap(),
};
let res = p.add(&q);
let res = p.add(&q).unwrap();
assert_eq!(
res.x.to_string(),
"7916061937171219682591368294088513039687205273691143098332585753343424131937"
@ -416,9 +425,9 @@ mod tests {
)
.unwrap(),
};
let res_m = p.mul_scalar(3.to_bigint().unwrap());
let res_a = p.add(&p);
let res_a = res_a.add(&p);
let res_m = p.mul_scalar(3.to_bigint().unwrap()).unwrap();
let res_a = p.add(&p).unwrap();
let res_a = res_a.add(&p).unwrap();
assert_eq!(res_m.x, res_a.x);
assert_eq!(
res_m.x.to_string(),
@ -434,7 +443,7 @@ mod tests {
10,
)
.unwrap();
let res2 = p.mul_scalar(n);
let res2 = p.mul_scalar(n).unwrap();
assert_eq!(
res2.x.to_string(),
"17070357974431721403481313912716834497662307308519659060910483826664480189605"
@ -448,7 +457,7 @@ mod tests {
#[test]
fn test_new_key_sign_verify_mimc_0() {
let sk = new_key();
let pk = sk.public();
let pk = sk.public().unwrap();
let msg = 5.to_bigint().unwrap();
let sig = sk.sign_mimc(msg.clone()).unwrap();
let v = verify_mimc(pk, sig, msg);
@ -458,7 +467,7 @@ mod tests {
#[test]
fn test_new_key_sign_verify_mimc_1() {
let sk = new_key();
let pk = sk.public();
let pk = sk.public().unwrap();
let msg = BigInt::parse_bytes(b"123456789012345678901234567890", 10).unwrap();
let sig = sk.sign_mimc(msg.clone()).unwrap();
let v = verify_mimc(pk, sig, msg);
@ -467,7 +476,7 @@ mod tests {
#[test]
fn test_new_key_sign_verify_poseidon_0() {
let sk = new_key();
let pk = sk.public();
let pk = sk.public().unwrap();
let msg = 5.to_bigint().unwrap();
let sig = sk.sign_poseidon(msg.clone()).unwrap();
let v = verify_poseidon(pk, sig, msg);
@ -477,7 +486,7 @@ mod tests {
#[test]
fn test_new_key_sign_verify_poseidon_1() {
let sk = new_key();
let pk = sk.public();
let pk = sk.public().unwrap();
let msg = BigInt::parse_bytes(b"123456789012345678901234567890", 10).unwrap();
let sig = sk.sign_poseidon(msg.clone()).unwrap();
let v = verify_poseidon(pk, sig, msg);
@ -559,7 +568,7 @@ mod tests {
h[31] = h[31] | 0x40;
let sk = BigInt::from_bytes_le(Sign::Plus, &h[..]);
let point = B8.mul_scalar(sk.clone());
let point = B8.mul_scalar(sk.clone()).unwrap();
let cmp_point = point.compress();
let dcmp_point = decompress_point(cmp_point).unwrap();
@ -571,7 +580,7 @@ mod tests {
#[test]
fn test_signature_compress_decompress() {
let sk = new_key();
let pk = sk.public();
let pk = sk.public().unwrap();
for i in 0..5 {
let msg_raw = "123456".to_owned() + &i.to_string();

+ 25
- 20
src/utils.rs

@ -9,11 +9,15 @@ pub fn modulus(a: &BigInt, m: &BigInt) -> BigInt {
((a % m) + m) % m
}
pub fn modinv(a: &BigInt, q: &BigInt) -> BigInt {
pub fn modinv(a: &BigInt, q: &BigInt) -> Result<BigInt, String> {
let big_zero: BigInt = Zero::zero();
if a == &big_zero {
return Err("no mod inv of Zero".to_string());
}
let mut mn = (q.clone(), a.clone());
let mut xy: (BigInt, BigInt) = (Zero::zero(), One::one());
let big_zero: BigInt = Zero::zero();
while mn.1 != big_zero {
xy = (xy.1.clone(), xy.0 - (mn.0.clone() / mn.1.clone()) * xy.1);
mn = (mn.1.clone(), modulus(&mn.0, &mn.1));
@ -22,7 +26,7 @@ pub fn modinv(a: &BigInt, q: &BigInt) -> BigInt {
while xy.0 < Zero::zero() {
xy.0 = modulus(&xy.0, q);
}
xy.0
Ok(xy.0)
}
/*
@ -102,7 +106,7 @@ pub fn concatenate_arrays(x: &[T], y: &[T]) -> Vec {
x.iter().chain(y).cloned().collect()
}
pub fn modsqrt(a: &BigInt, q: &BigInt) -> BigInt {
pub fn modsqrt(a: &BigInt, q: &BigInt) -> Result<BigInt, String> {
// Tonelli-Shanks Algorithm (https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm)
//
// This implementation is following the Go lang core implementation https://golang.org/src/math/big/int.go?s=23173:23210#L859
@ -112,15 +116,14 @@ pub fn modsqrt(a: &BigInt, q: &BigInt) -> BigInt {
let zero: BigInt = Zero::zero();
let one: BigInt = One::one();
if legendre_symbol(&a, q) != 1 {
// not a mod p square
return zero;
return Err("not a mod p square".to_string());
} else if a == &zero {
return zero;
return Err("not a mod p square".to_string());
} else if q == &2.to_bigint().unwrap() {
return zero;
return Err("not a mod p square".to_string());
} else if q % 4.to_bigint().unwrap() == 3.to_bigint().unwrap() {
let r = a.modpow(&((q + one) / 4), &q);
return r;
return Ok(r);
}
let mut s = q - &one;
@ -149,7 +152,7 @@ pub fn modsqrt(a: &BigInt, q: &BigInt) -> BigInt {
}
if m == zero {
return y.clone();
return Ok(y.clone());
}
t = g.modpow(&(2.to_bigint().unwrap().modpow(&(&r - &m - 1), q)), q);
@ -161,7 +164,7 @@ pub fn modsqrt(a: &BigInt, q: &BigInt) -> BigInt {
}
#[allow(dead_code)]
pub fn modsqrt_v2(a: &BigInt, q: &BigInt) -> BigInt {
pub fn modsqrt_v2(a: &BigInt, q: &BigInt) -> Result<BigInt, String> {
// Tonelli-Shanks Algorithm (https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm)
//
// This implementation is following this Python implementation by Dusk https://github.com/dusk-network/dusk-zerocaf/blob/master/tools/tonelli.py
@ -169,15 +172,14 @@ pub fn modsqrt_v2(a: &BigInt, q: &BigInt) -> BigInt {
let zero: BigInt = Zero::zero();
let one: BigInt = One::one();
if legendre_symbol(&a, q) != 1 {
// not a mod p square
return zero;
return Err("not a mod p square".to_string());
} else if a == &zero {
return zero;
return Err("not a mod p square".to_string());
} else if q == &2.to_bigint().unwrap() {
return zero;
return Err("not a mod p square".to_string());
} else if q % 4.to_bigint().unwrap() == 3.to_bigint().unwrap() {
let r = a.modpow(&((q + one) / 4), &q);
return r;
return Ok(r);
}
let mut p = q - &one;
@ -214,7 +216,7 @@ pub fn modsqrt_v2(a: &BigInt, q: &BigInt) -> BigInt {
c = modulus(&(&b * &b), q);
m = i.clone();
}
return x;
return Ok(x);
}
pub fn legendre_symbol(a: &BigInt, q: &BigInt) -> i32 {
@ -235,7 +237,10 @@ mod tests {
fn test_mod_inverse() {
let a = BigInt::parse_bytes(b"123456789123456789123456789123456789123456789", 10).unwrap();
let b = BigInt::parse_bytes(b"12345678", 10).unwrap();
assert_eq!(modinv(&a, &b), BigInt::parse_bytes(b"641883", 10).unwrap());
assert_eq!(
modinv(&a, &b).unwrap(),
BigInt::parse_bytes(b"641883", 10).unwrap()
);
}
#[test]
@ -252,11 +257,11 @@ mod tests {
.unwrap();
assert_eq!(
(modsqrt(&a, &q)).to_string(),
(modsqrt(&a, &q).unwrap()).to_string(),
"5464794816676661649783249706827271879994893912039750480019443499440603127256"
);
assert_eq!(
(modsqrt_v2(&a, &q)).to_string(),
(modsqrt_v2(&a, &q).unwrap()).to_string(),
"5464794816676661649783249706827271879994893912039750480019443499440603127256"
);
}

Loading…
Cancel
Save