From addcca64e55687e7ff7ea0e72075fe08a4c3268e Mon Sep 17 00:00:00 2001 From: arnaucube Date: Mon, 9 Sep 2019 20:42:17 +0200 Subject: [PATCH] add some error handling --- Cargo.toml | 2 +- README.md | 7 ++-- src/lib.rs | 111 ++++++++++++++++++++++++++++----------------------- src/utils.rs | 45 +++++++++++---------- 4 files changed, 90 insertions(+), 75 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 66cbd29..5f3d28e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "babyjubjub-rs" -version = "0.0.2" +version = "0.0.3" authors = ["arnaucube "] edition = "2018" license = "GPL-3.0" diff --git a/README.md b/README.md index 376490b..d5dd137 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # babyjubjub-rs [![Crates.io](https://img.shields.io/crates/v/babyjubjub-rs.svg)](https://crates.io/crates/babyjubjub-rs) [![Build Status](https://travis-ci.org/arnaucube/babyjubjub-rs.svg?branch=master)](https://travis-ci.org/arnaucube/babyjubjub-rs) -BabyJubJub elliptic curve implementation in Rust. Is a twisted edwards curve embedded in the curve of BN128. +BabyJubJub elliptic curve implementation in Rust. A twisted edwards curve embedded in the curve of BN128. BabyJubJub curve explanation: https://medium.com/zokrates/efficient-ecc-in-zksnarks-using-zokrates-bd9ae37b8186 @@ -10,7 +10,7 @@ Uses: Compatible with the BabyJubJub Go implementation from https://github.com/iden3/go-iden3-crypto ## Warning -Doing this in my free time to get familiar with Rust, do not use in production. +Doing this in my free time to get familiar with Rust, **do not use in production**. - [x] point addition - [x] point scalar multiplication @@ -24,7 +24,8 @@ Doing this in my free time to get familiar with Rust, do not use in production. ### References - BabyJubJub curve explanation: https://medium.com/zokrates/efficient-ecc-in-zksnarks-using-zokrates-bd9ae37b8186 - - C++ https://github.com/barryWhiteHat/baby_jubjub_ecc + - C++ & Explanation https://github.com/barryWhiteHat/baby_jubjub + - C++ https://github.com/barryWhiteHat/baby_jubjub_ecc - Javascript & Circom: https://github.com/iden3/circomlib - Go https://github.com/iden3/go-iden3-crypto - JubJub curve explanation: https://z.cash/technology/jubjub/ diff --git a/src/lib.rs b/src/lib.rs index 978a6e3..2b489cb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -62,7 +62,7 @@ pub struct Point { } impl Point { - pub fn add(&self, q: &Point) -> Point { + pub fn add(&self, q: &Point) -> Result { // x = (x1*y2+y1*x2)/(c*(1+d*x1*x2*y1*y2)) // y = (y1*y2-x1*x2)/(c*(1-d*x1*x2*y1*y2)) @@ -70,19 +70,19 @@ impl Point { let one: BigInt = One::one(); let x_num: BigInt = &self.x * &q.y + &self.y * &q.x; let x_den: BigInt = &one + &D.clone() * &self.x * &q.x * &self.y * &q.y; - let x_den_inv = utils::modinv(&x_den, &Q); + let x_den_inv = utils::modinv(&x_den, &Q)?; let x: BigInt = utils::modulus(&(&x_num * &x_den_inv), &Q); // y = (y1 * y2 - a * x1 * x2) / (1 - d * x1 * x2 * y1 * y2) let y_num = &self.y * &q.y - &A.clone() * &self.x * &q.x; let y_den = utils::modulus(&(&one - &D.clone() * &self.x * &q.x * &self.y * &q.y), &Q); - let y_den_inv = utils::modinv(&y_den, &Q); + let y_den_inv = utils::modinv(&y_den, &Q)?; let y: BigInt = utils::modulus(&(&y_num * &y_den_inv), &Q); - Point { x: x, y: y } + Ok(Point { x: x, y: y }) } - pub fn mul_scalar(&self, n: BigInt) -> Point { + pub fn mul_scalar(&self, n: BigInt) -> Result { // TODO use & in n to avoid clones on function call let mut r: Point = Point { x: Zero::zero(), @@ -96,14 +96,14 @@ impl Point { while rem != zero { let is_odd = &rem & &one == one; if is_odd == true { - r = r.add(&exp); + r = r.add(&exp)?; } - exp = exp.add(&exp); + exp = exp.add(&exp)?; rem = rem >> 1; } r.x = utils::modulus(&r.x, &Q); r.y = utils::modulus(&r.y, &Q); - r + Ok(r) } pub fn compress(&self) -> [u8; 32] { @@ -111,7 +111,7 @@ impl Point { let (_, y_bytes) = self.y.to_bytes_le(); let len = min(y_bytes.len(), r.len()); r[..len].copy_from_slice(&y_bytes[..len]); - if &self.x >= &(&Q.clone() >> 1) { + if &self.x > &(&Q.clone() >> 1) { r[31] = r[31] | 0x80; } r @@ -133,18 +133,15 @@ pub fn decompress_point(bb: [u8; 32]) -> Result { let one: BigInt = One::one(); // x^2 = (1 - y^2) / (a - d * y^2) (mod p) - let mut x: BigInt = utils::modulus( - &((one - utils::modulus(&(&y * &y), &Q)) - * utils::modinv( - &utils::modulus( - &(&A.clone() - utils::modulus(&(&D.clone() * (&y * &y)), &Q)), - &Q, - ), - &Q, - )), + let den = utils::modinv( + &utils::modulus( + &(&A.clone() - utils::modulus(&(&D.clone() * (&y * &y)), &Q)), + &Q, + ), &Q, - ); - x = utils::modsqrt(&x, &Q); + )?; + let mut x: BigInt = utils::modulus(&((one - utils::modulus(&(&y * &y), &Q)) * den), &Q); + x = utils::modsqrt(&x, &Q)?; if sign && !(&x > &(&Q.clone() >> 1)) || (!sign && (&x > &(&Q.clone() >> 1))) { x = x * -1.to_bigint().unwrap(); @@ -191,10 +188,10 @@ pub struct PrivateKey { } impl PrivateKey { - pub fn public(&self) -> Point { + pub fn public(&self) -> Result { // https://tools.ietf.org/html/rfc8032#section-5.1.5 - let pk = B8.mul_scalar(self.key.clone()); - pk.clone() + let pk = B8.mul_scalar(self.key.clone())?; + Ok(pk.clone()) } pub fn sign_mimc(&self, msg: BigInt) -> Result { @@ -209,15 +206,12 @@ impl PrivateKey { let r_bytes = utils::concatenate_arrays(s, &msg_bytes); let mut r = BigInt::from_bytes_be(Sign::Plus, &r_bytes[..]); r = utils::modulus(&r, &SUBORDER); - let r8: Point = B8.mul_scalar(r.clone()); - let a = &self.public(); + let r8: Point = B8.mul_scalar(r.clone())?; + let a = &self.public()?; let hm_input = vec![r8.x.clone(), r8.y.clone(), a.x.clone(), a.y.clone(), msg]; let mimc7 = Mimc7::new(); - let hm = match mimc7.hash(hm_input) { - Result::Err(err) => return Err(err.to_string()), - Result::Ok(hm) => hm, - }; + let hm = mimc7.hash(hm_input)?; let mut s = &self.key << 3; s = hm * s; @@ -241,15 +235,12 @@ impl PrivateKey { let r_bytes = utils::concatenate_arrays(s, &msg_bytes); let mut r = BigInt::from_bytes_be(Sign::Plus, &r_bytes[..]); r = utils::modulus(&r, &SUBORDER); - let r8: Point = B8.mul_scalar(r.clone()); - let a = &self.public(); + let r8: Point = B8.mul_scalar(r.clone())?; + let a = &self.public()?; let hm_input = vec![r8.x.clone(), r8.y.clone(), a.x.clone(), a.y.clone(), msg]; let poseidon = Poseidon::new(); - let hm = match poseidon.hash(hm_input) { - Result::Err(err) => return Err(err.to_string()), - Result::Ok(hm) => hm, - }; + let hm = poseidon.hash(hm_input)?; let mut s = &self.key << 3; s = hm * s; @@ -295,8 +286,17 @@ pub fn verify_mimc(pk: Point, sig: Signature, msg: BigInt) -> bool { Result::Err(_) => return false, Result::Ok(hm) => hm, }; - let l = B8.mul_scalar(sig.s); - let r = sig.r_b8.add(&pk.mul_scalar(8.to_bigint().unwrap() * hm)); + let l = match B8.mul_scalar(sig.s) { + Result::Err(_) => return false, + Result::Ok(l) => l, + }; + let r = match sig + .r_b8 + .add(&pk.mul_scalar(8.to_bigint().unwrap() * hm).unwrap()) + { + Result::Err(_) => return false, + Result::Ok(r) => r, + }; if l.x == r.x && l.y == r.y { return true; } @@ -315,8 +315,17 @@ pub fn verify_poseidon(pk: Point, sig: Signature, msg: BigInt) -> bool { Result::Err(_) => return false, Result::Ok(hm) => hm, }; - let l = B8.mul_scalar(sig.s); - let r = sig.r_b8.add(&pk.mul_scalar(8.to_bigint().unwrap() * hm)); + let l = match B8.mul_scalar(sig.s) { + Result::Err(_) => return false, + Result::Ok(l) => l, + }; + let r = match sig + .r_b8 + .add(&pk.mul_scalar(8.to_bigint().unwrap() * hm).unwrap()) + { + Result::Err(_) => return false, + Result::Ok(r) => r, + }; if l.x == r.x && l.y == r.y { return true; } @@ -355,7 +364,7 @@ mod tests { ) .unwrap(), }; - let res = p.add(&q); + let res = p.add(&q).unwrap(); assert_eq!( res.x.to_string(), "6890855772600357754907169075114257697580319025794532037257385534741338397365" @@ -391,7 +400,7 @@ mod tests { ) .unwrap(), }; - let res = p.add(&q); + let res = p.add(&q).unwrap(); assert_eq!( res.x.to_string(), "7916061937171219682591368294088513039687205273691143098332585753343424131937" @@ -416,9 +425,9 @@ mod tests { ) .unwrap(), }; - let res_m = p.mul_scalar(3.to_bigint().unwrap()); - let res_a = p.add(&p); - let res_a = res_a.add(&p); + let res_m = p.mul_scalar(3.to_bigint().unwrap()).unwrap(); + let res_a = p.add(&p).unwrap(); + let res_a = res_a.add(&p).unwrap(); assert_eq!(res_m.x, res_a.x); assert_eq!( res_m.x.to_string(), @@ -434,7 +443,7 @@ mod tests { 10, ) .unwrap(); - let res2 = p.mul_scalar(n); + let res2 = p.mul_scalar(n).unwrap(); assert_eq!( res2.x.to_string(), "17070357974431721403481313912716834497662307308519659060910483826664480189605" @@ -448,7 +457,7 @@ mod tests { #[test] fn test_new_key_sign_verify_mimc_0() { let sk = new_key(); - let pk = sk.public(); + let pk = sk.public().unwrap(); let msg = 5.to_bigint().unwrap(); let sig = sk.sign_mimc(msg.clone()).unwrap(); let v = verify_mimc(pk, sig, msg); @@ -458,7 +467,7 @@ mod tests { #[test] fn test_new_key_sign_verify_mimc_1() { let sk = new_key(); - let pk = sk.public(); + let pk = sk.public().unwrap(); let msg = BigInt::parse_bytes(b"123456789012345678901234567890", 10).unwrap(); let sig = sk.sign_mimc(msg.clone()).unwrap(); let v = verify_mimc(pk, sig, msg); @@ -467,7 +476,7 @@ mod tests { #[test] fn test_new_key_sign_verify_poseidon_0() { let sk = new_key(); - let pk = sk.public(); + let pk = sk.public().unwrap(); let msg = 5.to_bigint().unwrap(); let sig = sk.sign_poseidon(msg.clone()).unwrap(); let v = verify_poseidon(pk, sig, msg); @@ -477,7 +486,7 @@ mod tests { #[test] fn test_new_key_sign_verify_poseidon_1() { let sk = new_key(); - let pk = sk.public(); + let pk = sk.public().unwrap(); let msg = BigInt::parse_bytes(b"123456789012345678901234567890", 10).unwrap(); let sig = sk.sign_poseidon(msg.clone()).unwrap(); let v = verify_poseidon(pk, sig, msg); @@ -559,7 +568,7 @@ mod tests { h[31] = h[31] | 0x40; let sk = BigInt::from_bytes_le(Sign::Plus, &h[..]); - let point = B8.mul_scalar(sk.clone()); + let point = B8.mul_scalar(sk.clone()).unwrap(); let cmp_point = point.compress(); let dcmp_point = decompress_point(cmp_point).unwrap(); @@ -571,7 +580,7 @@ mod tests { #[test] fn test_signature_compress_decompress() { let sk = new_key(); - let pk = sk.public(); + let pk = sk.public().unwrap(); for i in 0..5 { let msg_raw = "123456".to_owned() + &i.to_string(); diff --git a/src/utils.rs b/src/utils.rs index bc2d293..3763040 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -9,11 +9,15 @@ pub fn modulus(a: &BigInt, m: &BigInt) -> BigInt { ((a % m) + m) % m } -pub fn modinv(a: &BigInt, q: &BigInt) -> BigInt { +pub fn modinv(a: &BigInt, q: &BigInt) -> Result { + let big_zero: BigInt = Zero::zero(); + if a == &big_zero { + return Err("no mod inv of Zero".to_string()); + } + let mut mn = (q.clone(), a.clone()); let mut xy: (BigInt, BigInt) = (Zero::zero(), One::one()); - let big_zero: BigInt = Zero::zero(); while mn.1 != big_zero { xy = (xy.1.clone(), xy.0 - (mn.0.clone() / mn.1.clone()) * xy.1); mn = (mn.1.clone(), modulus(&mn.0, &mn.1)); @@ -22,7 +26,7 @@ pub fn modinv(a: &BigInt, q: &BigInt) -> BigInt { while xy.0 < Zero::zero() { xy.0 = modulus(&xy.0, q); } - xy.0 + Ok(xy.0) } /* @@ -102,7 +106,7 @@ pub fn concatenate_arrays(x: &[T], y: &[T]) -> Vec { x.iter().chain(y).cloned().collect() } -pub fn modsqrt(a: &BigInt, q: &BigInt) -> BigInt { +pub fn modsqrt(a: &BigInt, q: &BigInt) -> Result { // Tonelli-Shanks Algorithm (https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm) // // This implementation is following the Go lang core implementation https://golang.org/src/math/big/int.go?s=23173:23210#L859 @@ -112,15 +116,14 @@ pub fn modsqrt(a: &BigInt, q: &BigInt) -> BigInt { let zero: BigInt = Zero::zero(); let one: BigInt = One::one(); if legendre_symbol(&a, q) != 1 { - // not a mod p square - return zero; + return Err("not a mod p square".to_string()); } else if a == &zero { - return zero; + return Err("not a mod p square".to_string()); } else if q == &2.to_bigint().unwrap() { - return zero; + return Err("not a mod p square".to_string()); } else if q % 4.to_bigint().unwrap() == 3.to_bigint().unwrap() { let r = a.modpow(&((q + one) / 4), &q); - return r; + return Ok(r); } let mut s = q - &one; @@ -149,7 +152,7 @@ pub fn modsqrt(a: &BigInt, q: &BigInt) -> BigInt { } if m == zero { - return y.clone(); + return Ok(y.clone()); } t = g.modpow(&(2.to_bigint().unwrap().modpow(&(&r - &m - 1), q)), q); @@ -161,7 +164,7 @@ pub fn modsqrt(a: &BigInt, q: &BigInt) -> BigInt { } #[allow(dead_code)] -pub fn modsqrt_v2(a: &BigInt, q: &BigInt) -> BigInt { +pub fn modsqrt_v2(a: &BigInt, q: &BigInt) -> Result { // Tonelli-Shanks Algorithm (https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm) // // This implementation is following this Python implementation by Dusk https://github.com/dusk-network/dusk-zerocaf/blob/master/tools/tonelli.py @@ -169,15 +172,14 @@ pub fn modsqrt_v2(a: &BigInt, q: &BigInt) -> BigInt { let zero: BigInt = Zero::zero(); let one: BigInt = One::one(); if legendre_symbol(&a, q) != 1 { - // not a mod p square - return zero; + return Err("not a mod p square".to_string()); } else if a == &zero { - return zero; + return Err("not a mod p square".to_string()); } else if q == &2.to_bigint().unwrap() { - return zero; + return Err("not a mod p square".to_string()); } else if q % 4.to_bigint().unwrap() == 3.to_bigint().unwrap() { let r = a.modpow(&((q + one) / 4), &q); - return r; + return Ok(r); } let mut p = q - &one; @@ -214,7 +216,7 @@ pub fn modsqrt_v2(a: &BigInt, q: &BigInt) -> BigInt { c = modulus(&(&b * &b), q); m = i.clone(); } - return x; + return Ok(x); } pub fn legendre_symbol(a: &BigInt, q: &BigInt) -> i32 { @@ -235,7 +237,10 @@ mod tests { fn test_mod_inverse() { let a = BigInt::parse_bytes(b"123456789123456789123456789123456789123456789", 10).unwrap(); let b = BigInt::parse_bytes(b"12345678", 10).unwrap(); - assert_eq!(modinv(&a, &b), BigInt::parse_bytes(b"641883", 10).unwrap()); + assert_eq!( + modinv(&a, &b).unwrap(), + BigInt::parse_bytes(b"641883", 10).unwrap() + ); } #[test] @@ -252,11 +257,11 @@ mod tests { .unwrap(); assert_eq!( - (modsqrt(&a, &q)).to_string(), + (modsqrt(&a, &q).unwrap()).to_string(), "5464794816676661649783249706827271879994893912039750480019443499440603127256" ); assert_eq!( - (modsqrt_v2(&a, &q)).to_string(), + (modsqrt_v2(&a, &q).unwrap()).to_string(), "5464794816676661649783249706827271879994893912039750480019443499440603127256" ); }