diff --git a/Cargo.toml b/Cargo.toml index 8b15289..040ef95 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,3 +16,4 @@ generic-array = "0.13.2" tiny-keccak = "1.5" rustc-hex = "1.0.0" mimc-rs = "0.0.1" +arrayref = "0.3.5" diff --git a/README.md b/README.md index 8979f11..920bd38 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ BabyJubJub elliptic curve implementation in Rust Uses MiMC7 hash function: https://github.com/arnaucube/mimc-rs ## Warning -Doing this in my free time to get familiar with Rust, do not use in production +Doing this in my free time to get familiar with Rust, do not use in production. - [x] point addition - [x] point scalar multiplication diff --git a/src/lib.rs b/src/lib.rs index e789c0e..19bd79b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,5 @@ +#[macro_use] +extern crate arrayref; extern crate generic_array; extern crate mimc_rs; extern crate num; @@ -8,9 +10,7 @@ extern crate rand; use blake2::{Blake2b, Digest}; use mimc_rs::Mimc7; -use num_bigint::RandBigInt; - -use num_bigint::{BigInt, Sign, ToBigInt}; +use num_bigint::{BigInt, RandBigInt, Sign, ToBigInt}; use num_traits::{One, Zero}; use generic_array::GenericArray; @@ -22,11 +22,56 @@ pub struct Point { pub x: BigInt, pub y: BigInt, } + pub struct Signature { r_b8: Point, s: BigInt, } +pub struct PrivateKey { + bbjj: Babyjubjub, + key: BigInt, +} + +impl PrivateKey { + pub fn public(&self) -> Point { + // https://tools.ietf.org/html/rfc8032#section-5.1.5 + let pk = &self.bbjj.mul_scalar(self.bbjj.b8.clone(), self.key.clone()); + pk.clone() + } + + pub fn sign(&self, msg: BigInt) -> Signature { + // https://tools.ietf.org/html/rfc8032#section-5.1.6 + let mut hasher = Blake2b::new(); + let (_, sk_bytes) = self.key.to_bytes_be(); + hasher.input(sk_bytes); + let mut h = hasher.result(); // h: hash(sk) + // s: h[32:64] + let s = GenericArray::::from_mut_slice(&mut h[32..64]); + let (_, msg_bytes) = msg.to_bytes_be(); + let r_bytes = utils::concatenate_arrays(s, &msg_bytes); + let mut r = BigInt::from_bytes_be(Sign::Plus, &r_bytes[..]); + r = utils::modulus(&r, &self.bbjj.sub_order); + let r8: Point = self.bbjj.mul_scalar(self.bbjj.b8.clone(), r.clone()); + // let a = &self.sk_to_pk(sk.clone()); + let a = &self.public(); + + let hm_input = vec![r8.x.clone(), r8.y.clone(), a.x.clone(), a.y.clone(), msg]; + let mimc7 = Mimc7::new(); + let hm = mimc7.hash(hm_input); + + let mut s = &self.key << 3; + s = hm * s; + s = r + s; + s = s % &self.bbjj.sub_order; + + Signature { + r_b8: r8.clone(), + s: s, + } + } +} + pub struct Babyjubjub { d: BigInt, a: BigInt, @@ -82,17 +127,13 @@ impl Babyjubjub { let one: BigInt = One::one(); let x_num: BigInt = &p.x * &q.y + &p.y * &q.x; let x_den: BigInt = &one + &self.d * &p.x * &q.x * &p.y * &q.y; - let x_den_inv = utils::mod_inverse0(&x_den, &self.q); - // let x_den_inv = utils::mod_inverse1(x_den, self.q.clone()); - // let x_den_inv = utils::mod_inverse2(x_den, self.q.clone()); + let x_den_inv = utils::modinv(&x_den, &self.q); let x: BigInt = utils::modulus(&(&x_num * &x_den_inv), &self.q); // y = (y1 * y2 - a * x1 * x2) / (1 - d * x1 * x2 * y1 * y2) let y_num = &p.y * &q.y - &self.a * &p.x * &q.x; let y_den = utils::modulus(&(&one - &self.d * &p.x * &q.x * &p.y * &q.y), &self.q); - let y_den_inv = utils::mod_inverse0(&y_den, &self.q); - // let y_den_inv = utils::mod_inverse1(y_den, self.q.clone()); - // let y_den_inv = utils::mod_inverse2(y_den, self.q.clone()); + let y_den_inv = utils::modinv(&y_den, &self.q); let y: BigInt = utils::modulus(&(&y_num * &y_den_inv), &self.q); Point { x: x, y: y } @@ -122,7 +163,52 @@ impl Babyjubjub { r } - pub fn new_key(&self) -> BigInt { + pub fn compress(&self, p: Point) -> [u8; 32] { + let mut r: [u8; 32]; + let (_, y_bytes) = p.y.to_bytes_le(); + r = *array_ref!(y_bytes, 0, 32); + if &p.x > &(&self.q >> 1) { + r[31] = r[31] | 0x80; + } + r + } + + pub fn decompress_point(&self, bb: [u8; 32]) -> Point { + // https://tools.ietf.org/html/rfc8032#section-5.2.3 + let mut sign: bool = false; + let mut b = bb.clone(); + if b[31] & 0x80 != 0x00 { + sign = true; + b[31] = b[31] & 0x7F; + } + let y: BigInt = BigInt::from_bytes_le(Sign::Plus, &b[..]); + if y >= self.q { + // println!("ERROR0"); + } + let one: BigInt = One::one(); + + // x^2 = (1 - y^2) / (a - d * y^2) (mod p) + let mut x: BigInt = utils::modulus( + &((one - utils::modulus(&(&y * &y), &self.q)) + * utils::modinv( + &utils::modulus( + &(&self.a - utils::modulus(&(&self.d * (&y * &y)), &self.q)), + &self.q, + ), + &self.q, + )), + &self.q, + ); + x = utils::modsqrt(&x, &self.q); + + if (sign && x >= Zero::zero()) || (!sign && x < Zero::zero()) { + x = x * -1.to_bigint().unwrap(); + } + x = utils::modulus(&x, &self.q); + Point { x: x, y: y } + } + + pub fn new_key(&self) -> PrivateKey { // https://tools.ietf.org/html/rfc8032#section-5.1.5 let mut rng = rand::thread_rng(); let sk_raw = rng.gen_biguint(1024).to_bigint().unwrap(); @@ -138,43 +224,16 @@ impl Babyjubjub { let sk = BigInt::from_bytes_le(Sign::Plus, &h[..]); - sk - } - - pub fn sk_to_pk(&self, sk: BigInt) -> Point { - // https://tools.ietf.org/html/rfc8032#section-5.1.5 - // TODO this will be moved into a method of PrivateKey type - let pk = &self.mul_scalar(self.b8.clone(), sk); - pk.clone() - } - - pub fn sign(&self, sk: BigInt, msg: BigInt) -> Signature { - // https://tools.ietf.org/html/rfc8032#section-5.1.6 - let mut hasher = Blake2b::new(); - let (_, sk_bytes) = sk.to_bytes_be(); - hasher.input(sk_bytes); - let mut h = hasher.result(); // h: hash(sk) - // s: h[32:64] - let s = GenericArray::::from_mut_slice(&mut h[32..64]); - let (_, msg_bytes) = msg.to_bytes_be(); - let r_bytes = utils::concatenate_arrays(s, &msg_bytes); - let mut r = BigInt::from_bytes_be(Sign::Plus, &r_bytes[..]); - r = utils::modulus(&r, &self.sub_order); - let r8: Point = self.mul_scalar(self.b8.clone(), r.clone()); - let a = &self.sk_to_pk(sk.clone()); - - let hm_input = vec![r8.x.clone(), r8.y.clone(), a.x.clone(), a.y.clone(), msg]; - let mimc7 = Mimc7::new(); - let hm = mimc7.hash(hm_input); - - let mut s = sk << 3; - s = hm * s; - s = r + s; - s = s % &self.sub_order; - - Signature { - r_b8: r8.clone(), - s: s, + let bbjj_new = Babyjubjub { + d: self.d.clone(), + a: self.a.clone(), + q: self.q.clone(), + b8: self.b8.clone(), + sub_order: self.sub_order.clone(), + }; + PrivateKey { + bbjj: bbjj_new, + key: sk, } } @@ -200,6 +259,8 @@ impl Babyjubjub { #[cfg(test)] mod tests { use super::*; + extern crate rustc_hex; + use rustc_hex::ToHex; #[test] fn test_add_same_point() { @@ -321,12 +382,48 @@ mod tests { } #[test] - fn test_new_key_sign_verify() { + fn test_point_compress_decompress() { + let bbjj = Babyjubjub::new(); + let p: Point = Point { + x: BigInt::parse_bytes( + b"17777552123799933955779906779655732241715742912184938656739573121738514868268", + 10, + ) + .unwrap(), + y: BigInt::parse_bytes( + b"2626589144620713026669568689430873010625803728049924121243784502389097019475", + 10, + ) + .unwrap(), + }; + let p_comp = bbjj.compress(p.clone()); + assert_eq!( + p_comp[..].to_hex(), + "53b81ed5bffe9545b54016234682e7b2f699bd42a5e9eae27ff4051bc698ce85" + ); + let p2 = bbjj.decompress_point(p_comp); + assert_eq!(p.x, p2.x); + assert_eq!(p.y, p2.y); + } + + #[test] + fn test_new_key_sign_verify0() { let bbjj = Babyjubjub::new(); let sk = bbjj.new_key(); - let pk = bbjj.sk_to_pk(sk.clone()); + let pk = sk.public(); let msg = 5.to_bigint().unwrap(); - let sig = bbjj.sign(sk, msg.clone()); + let sig = sk.sign(msg.clone()); + let v = bbjj.verify(pk, sig, msg); + assert_eq!(v, true); + } + + #[test] + fn test_new_key_sign_verify1() { + let bbjj = Babyjubjub::new(); + let sk = bbjj.new_key(); + let pk = sk.public(); + let msg = BigInt::parse_bytes(b"123456789012345678901234567890", 10).unwrap(); + let sig = sk.sign(msg.clone()); let v = bbjj.verify(pk, sig, msg); assert_eq!(v, true); } diff --git a/src/utils.rs b/src/utils.rs index 8f8bed5..bc2d293 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -2,14 +2,14 @@ extern crate num; extern crate num_bigint; extern crate num_traits; -use num_bigint::BigInt; +use num_bigint::{BigInt, ToBigInt}; use num_traits::{One, Zero}; pub fn modulus(a: &BigInt, m: &BigInt) -> BigInt { ((a % m) + m) % m } -pub fn mod_inverse0(a: &BigInt, q: &BigInt) -> BigInt { +pub fn modinv(a: &BigInt, q: &BigInt) -> BigInt { let mut mn = (q.clone(), a.clone()); let mut xy: (BigInt, BigInt) = (Zero::zero(), One::one()); @@ -26,13 +26,13 @@ pub fn mod_inverse0(a: &BigInt, q: &BigInt) -> BigInt { } /* -pub fn mod_inverse1(a0: BigInt, m0: BigInt) -> BigInt { - if m0 == One::one() { +pub fn modinv_v2(a0: &BigInt, m0: &BigInt) -> BigInt { + if m0 == &One::one() { return One::one(); } let (mut a, mut m, mut x0, mut inv): (BigInt, BigInt, BigInt, BigInt) = - (a0, m0.clone(), Zero::zero(), One::one()); + (a0.clone(), m0.clone(), Zero::zero(), One::one()); while a > One::one() { inv = inv - (&a / m.clone()) * x0.clone(); @@ -47,9 +47,9 @@ pub fn mod_inverse1(a0: BigInt, m0: BigInt) -> BigInt { inv } -pub fn mod_inverse2(a: BigInt, q: BigInt) -> BigInt { - let mut aa: BigInt = a; - let mut qq: BigInt = q; +pub fn modinv_v3(a: &BigInt, q: &BigInt) -> BigInt { + let mut aa: BigInt = a.clone(); + let mut qq: BigInt = q.clone(); if qq < Zero::zero() { qq = -qq; } @@ -68,12 +68,165 @@ pub fn mod_inverse2(a: BigInt, q: BigInt) -> BigInt { } res } +pub fn modinv_v4(x: &BigInt, q: &BigInt) -> BigInt { + let (gcd, inverse, _) = extended_gcd(x.clone(), q.clone()); + let one: BigInt = One::one(); + if gcd == one { + modulus(&inverse, q) + } else { + panic!("error: gcd!=one") + } +} +pub fn extended_gcd(a: BigInt, b: BigInt) -> (BigInt, BigInt, BigInt) { + let (mut s, mut old_s) = (BigInt::zero(), BigInt::one()); + let (mut t, mut old_t) = (BigInt::one(), BigInt::zero()); + let (mut r, mut old_r) = (b, a); + + while r != BigInt::zero() { + let quotient = &old_r / &r; + old_r -= "ient * &r; + std::mem::swap(&mut old_r, &mut r); + old_s -= "ient * &s; + std::mem::swap(&mut old_s, &mut s); + old_t -= quotient * &t; + std::mem::swap(&mut old_t, &mut t); + } + + let _quotients = (t, s); // == (a, b) / gcd + + (old_r, old_s, old_t) +} */ pub fn concatenate_arrays(x: &[T], y: &[T]) -> Vec { x.iter().chain(y).cloned().collect() } +pub fn modsqrt(a: &BigInt, q: &BigInt) -> BigInt { + // Tonelli-Shanks Algorithm (https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm) + // + // This implementation is following the Go lang core implementation https://golang.org/src/math/big/int.go?s=23173:23210#L859 + // Also described in https://www.maa.org/sites/default/files/pdf/upload_library/22/Polya/07468342.di020786.02p0470a.pdf + // -> section 6 + + let zero: BigInt = Zero::zero(); + let one: BigInt = One::one(); + if legendre_symbol(&a, q) != 1 { + // not a mod p square + return zero; + } else if a == &zero { + return zero; + } else if q == &2.to_bigint().unwrap() { + return zero; + } else if q % 4.to_bigint().unwrap() == 3.to_bigint().unwrap() { + let r = a.modpow(&((q + one) / 4), &q); + return r; + } + + let mut s = q - &one; + let mut e: BigInt = Zero::zero(); + while &s % 2 == zero { + s = s >> 1; + e = e + &one; + } + + let mut n: BigInt = 2.to_bigint().unwrap(); + while legendre_symbol(&n, q) != -1 { + n = &n + &one; + } + + let mut y = a.modpow(&((&s + &one) >> 1), q); + let mut b = a.modpow(&s, q); + let mut g = n.modpow(&s, q); + let mut r = e; + + loop { + let mut t = b.clone(); + let mut m: BigInt = Zero::zero(); + while &t != &one { + t = modulus(&(&t * &t), q); + m = m + &one; + } + + if m == zero { + return y.clone(); + } + + t = g.modpow(&(2.to_bigint().unwrap().modpow(&(&r - &m - 1), q)), q); + g = g.modpow(&(2.to_bigint().unwrap().modpow(&(r - &m), q)), q); + y = modulus(&(y * t), q); + b = modulus(&(b * &g), q); + r = m.clone(); + } +} + +#[allow(dead_code)] +pub fn modsqrt_v2(a: &BigInt, q: &BigInt) -> BigInt { + // Tonelli-Shanks Algorithm (https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm) + // + // This implementation is following this Python implementation by Dusk https://github.com/dusk-network/dusk-zerocaf/blob/master/tools/tonelli.py + + let zero: BigInt = Zero::zero(); + let one: BigInt = One::one(); + if legendre_symbol(&a, q) != 1 { + // not a mod p square + return zero; + } else if a == &zero { + return zero; + } else if q == &2.to_bigint().unwrap() { + return zero; + } else if q % 4.to_bigint().unwrap() == 3.to_bigint().unwrap() { + let r = a.modpow(&((q + one) / 4), &q); + return r; + } + + let mut p = q - &one; + let mut s: BigInt = Zero::zero(); + while &p % 2.to_bigint().unwrap() == zero { + s = s + &one; + p = p >> 1; + } + + let mut z: BigInt = One::one(); + while legendre_symbol(&z, q) != -1 { + z = &z + &one; + } + let mut c = z.modpow(&p, q); + + let mut x = a.modpow(&((&p + &one) >> 1), q); + let mut t = a.modpow(&p, q); + let mut m = s; + + while &t != &one { + let mut i: BigInt = One::one(); + let mut e: BigInt = 2.to_bigint().unwrap(); + while i < m { + if t.modpow(&e, q) == one { + break; + } + e = e * 2.to_bigint().unwrap(); + i = i + &one; + } + + let b = c.modpow(&(2.to_bigint().unwrap().modpow(&(&m - &i - 1), q)), q); + x = modulus(&(x * &b), q); + t = modulus(&(t * &b * &b), q); + c = modulus(&(&b * &b), q); + m = i.clone(); + } + return x; +} + +pub fn legendre_symbol(a: &BigInt, q: &BigInt) -> i32 { + // returns 1 if has a square root modulo q + let one: BigInt = One::one(); + let ls: BigInt = a.modpow(&((q - &one) >> 1), &q); + if &(ls) == &(q - one) { + return -1; + } + 1 +} + #[cfg(test)] mod tests { use super::*; @@ -82,9 +235,29 @@ mod tests { fn test_mod_inverse() { let a = BigInt::parse_bytes(b"123456789123456789123456789123456789123456789", 10).unwrap(); let b = BigInt::parse_bytes(b"12345678", 10).unwrap(); + assert_eq!(modinv(&a, &b), BigInt::parse_bytes(b"641883", 10).unwrap()); + } + + #[test] + fn test_sqrtmod() { + let a = BigInt::parse_bytes( + b"6536923810004159332831702809452452174451353762940761092345538667656658715568", + 10, + ) + .unwrap(); + let q = BigInt::parse_bytes( + b"7237005577332262213973186563042994240857116359379907606001950938285454250989", + 10, + ) + .unwrap(); + + assert_eq!( + (modsqrt(&a, &q)).to_string(), + "5464794816676661649783249706827271879994893912039750480019443499440603127256" + ); assert_eq!( - mod_inverse0(&a, &b), - BigInt::parse_bytes(b"641883", 10).unwrap() + (modsqrt_v2(&a, &q)).to_string(), + "5464794816676661649783249706827271879994893912039750480019443499440603127256" ); } }