diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 58b7ded..d6d0796 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -12,4 +12,4 @@ jobs: - name: Run tests run: | cargo test --verbose - cargo test --verbose --no-default-features --features=aarch64 + cargo test --verbose --features=aarch64 diff --git a/Cargo.toml b/Cargo.toml index d5792e5..3c87354 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "babyjubjub-rs" -version = "0.0.9" +version = "0.0.10" authors = ["arnaucube "] edition = "2021" license = "GPL-3.0" diff --git a/benches/bench_babyjubjub.rs b/benches/bench_babyjubjub.rs index 67218bb..ed04df9 100644 --- a/benches/bench_babyjubjub.rs +++ b/benches/bench_babyjubjub.rs @@ -9,7 +9,7 @@ extern crate num; extern crate num_bigint; use num_bigint::{BigInt, ToBigInt}; -use babyjubjub_rs::{utils, Point}; +use babyjubjub_rs::Point; fn criterion_benchmark(c: &mut Criterion) { let p: Point = Point { @@ -44,7 +44,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); let sk = babyjubjub_rs::new_key(); - let pk = sk.public().unwrap(); + let pk = sk.public(); let msg = 5.to_bigint().unwrap(); c.bench_function("sign", |b| b.iter(|| sk.sign(msg.clone()))); let sig = sk.sign(msg.clone()).unwrap(); diff --git a/src/lib.rs b/src/lib.rs index 3a365a1..23d105b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,8 +8,11 @@ pub type Fr = poseidon_rs::Fr; // alias use arrayref::array_ref; -#[cfg(feature = "default")] -use blake_hash::Digest; +#[cfg(not(feature = "aarch64"))] +use blake_hash::Digest; // compatible version with Blake used at circomlib + +#[cfg(feature = "aarch64")] +extern crate blake; // compatible version with Blake used at circomlib use std::cmp::min; @@ -220,14 +223,14 @@ pub fn decompress_point(bb: [u8; 32]) -> Result { Ok(Point { x: x_fr, y: y_fr }) } -#[cfg(feature = "default")] +#[cfg(not(feature = "aarch64"))] fn blh(b: &[u8]) -> Vec { - let hash = blake_hash::Blake512::digest(&b); + let hash = blake_hash::Blake512::digest(b); hash.to_vec() } #[cfg(feature = "aarch64")] -fn blh(b: &Vec) -> Vec { +fn blh(b: &[u8]) -> Vec { let mut hash = [0; 64]; blake::hash(512, b, &mut hash).unwrap(); hash.to_vec() @@ -285,9 +288,11 @@ impl PrivateKey { // let mut h = hasher.finalize(); // compatible with circomlib implementation - let hash: Vec = blh(&self.key.to_vec()); + let hash: Vec = blh(&self.key); let mut h: Vec = hash[..32].to_vec(); + // prune buffer following RFC 8032 + // https://tools.ietf.org/html/rfc8032#page-13 h[0] &= 0xF8; h[31] &= 0x7F; h[31] |= 0x40; @@ -308,7 +313,7 @@ impl PrivateKey { // let mut hasher = Blake2b::new(); // hasher.update(sk_bytes); // let mut h = hasher.finalize(); // h: hash(sk), s: h[32:64] - let mut h: Vec = blh(&self.key.to_vec()); + let mut h: Vec = blh(&self.key); let (_, msg_bytes) = msg.to_bytes_le(); let mut msg32: [u8; 32] = [0; 32]; @@ -346,7 +351,7 @@ impl PrivateKey { let r = B8.mul_scalar(&k); // h = H(x, r, m) - let pk = &self.public(); + let pk = self.public(); let h = schnorr_hash(&pk, m, &r)?; // s= k+x·h @@ -409,8 +414,8 @@ pub fn verify(pk: Point, sig: Signature, msg: BigInt) -> bool { #[cfg(test)] mod tests { use super::*; - use rand::Rng; use ::hex; + use rand::Rng; #[test] fn test_add_same_point() { @@ -590,14 +595,16 @@ mod tests { #[test] fn test_point_decompress0() { - let y_bytes_raw = hex::decode("b5328f8791d48f20bec6e481d91c7ada235f1facf22547901c18656b6c3e042f") - .unwrap(); + let y_bytes_raw = + hex::decode("b5328f8791d48f20bec6e481d91c7ada235f1facf22547901c18656b6c3e042f") + .unwrap(); let mut y_bytes: [u8; 32] = [0; 32]; y_bytes.copy_from_slice(&y_bytes_raw); let p = decompress_point(y_bytes).unwrap(); - let expected_px_raw = hex::decode("b86cc8d9c97daef0afe1a4753c54fb2d8a530dc74c7eee4e72b3fdf2496d2113") - .unwrap(); + let expected_px_raw = + hex::decode("b86cc8d9c97daef0afe1a4753c54fb2d8a530dc74c7eee4e72b3fdf2496d2113") + .unwrap(); let mut e_px_bytes: [u8; 32] = [0; 32]; e_px_bytes.copy_from_slice(&expected_px_raw); let expected_px: Fr = @@ -607,14 +614,16 @@ mod tests { #[test] fn test_point_decompress1() { - let y_bytes_raw = hex::decode("70552d3ff548e09266ded29b33ce75139672b062b02aa66bb0d9247ffecf1d0b") - .unwrap(); + let y_bytes_raw = + hex::decode("70552d3ff548e09266ded29b33ce75139672b062b02aa66bb0d9247ffecf1d0b") + .unwrap(); let mut y_bytes: [u8; 32] = [0; 32]; y_bytes.copy_from_slice(&y_bytes_raw); let p = decompress_point(y_bytes).unwrap(); - let expected_px_raw = hex::decode("30f1635ba7d56f9cb32c3ffbe6dca508a68c7f43936af11a23c785ce98cb3404") - .unwrap(); + let expected_px_raw = + hex::decode("30f1635ba7d56f9cb32c3ffbe6dca508a68c7f43936af11a23c785ce98cb3404") + .unwrap(); let mut e_px_bytes: [u8; 32] = [0; 32]; e_px_bytes.copy_from_slice(&expected_px_raw); let expected_px: Fr = diff --git a/src/utils.rs b/src/utils.rs index acbecfb..d40e739 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -115,10 +115,10 @@ pub fn modsqrt(a: &BigInt, q: &BigInt) -> Result { let zero: BigInt = Zero::zero(); let one: BigInt = One::one(); - if legendre_symbol(&a, q) != 1 || a == &zero || q == &2.to_bigint().unwrap() { + if legendre_symbol(a, q) != 1 || a == &zero || q == &2.to_bigint().unwrap() { return Err("not a mod p square".to_string()); } else if q % 4.to_bigint().unwrap() == 3.to_bigint().unwrap() { - let r = a.modpow(&((q + one) / 4), &q); + let r = a.modpow(&((q + one) / 4), q); return Ok(r); } @@ -168,10 +168,10 @@ pub fn modsqrt_v2(a: &BigInt, q: &BigInt) -> Result { let zero: BigInt = Zero::zero(); let one: BigInt = One::one(); - if legendre_symbol(&a, q) != 1 || a == &zero || q == &2.to_bigint().unwrap() { + if legendre_symbol(a, q) != 1 || a == &zero || q == &2.to_bigint().unwrap() { return Err("not a mod p square".to_string()); } else if q % 4.to_bigint().unwrap() == 3.to_bigint().unwrap() { - let r = a.modpow(&((q + one) / 4), &q); + let r = a.modpow(&((q + one) / 4), q); return Ok(r); } @@ -215,7 +215,7 @@ pub fn modsqrt_v2(a: &BigInt, q: &BigInt) -> Result { pub fn legendre_symbol(a: &BigInt, q: &BigInt) -> i32 { // returns 1 if has a square root modulo q let one: BigInt = One::one(); - let ls: BigInt = a.modpow(&((q - &one) >> 1), &q); + let ls: BigInt = a.modpow(&((q - &one) >> 1), q); if ls == q - one { return -1; }