extern crate num; extern crate num_bigint; extern crate num_traits; use num_bigint::{BigInt, ToBigInt}; use num_traits::{One, Zero}; pub fn modulus(a: &BigInt, m: &BigInt) -> BigInt { ((a % m) + m) % m } pub fn modinv(a: &BigInt, q: &BigInt) -> BigInt { let mut mn = (q.clone(), a.clone()); let mut xy: (BigInt, BigInt) = (Zero::zero(), One::one()); let big_zero: BigInt = Zero::zero(); while mn.1 != big_zero { xy = (xy.1.clone(), xy.0 - (mn.0.clone() / mn.1.clone()) * xy.1); mn = (mn.1.clone(), modulus(&mn.0, &mn.1)); } while xy.0 < Zero::zero() { xy.0 = modulus(&xy.0, q); } xy.0 } /* pub fn modinv_v2(a0: &BigInt, m0: &BigInt) -> BigInt { if m0 == &One::one() { return One::one(); } let (mut a, mut m, mut x0, mut inv): (BigInt, BigInt, BigInt, BigInt) = (a0.clone(), m0.clone(), Zero::zero(), One::one()); while a > One::one() { inv = inv - (&a / m.clone()) * x0.clone(); a = a % m.clone(); std::mem::swap(&mut a, &mut m); std::mem::swap(&mut x0, &mut inv); } if inv < Zero::zero() { inv += m0.clone() } inv } pub fn modinv_v3(a: &BigInt, q: &BigInt) -> BigInt { let mut aa: BigInt = a.clone(); let mut qq: BigInt = q.clone(); if qq < Zero::zero() { qq = -qq; } if aa < Zero::zero() { aa = -aa; } let d = num::Integer::gcd(&aa, &qq); if d != One::one() { println!("ERR no mod_inv"); } let res: BigInt; if d < Zero::zero() { res = d + qq; } else { res = d; } res } pub fn modinv_v4(x: &BigInt, q: &BigInt) -> BigInt { let (gcd, inverse, _) = extended_gcd(x.clone(), q.clone()); let one: BigInt = One::one(); if gcd == one { modulus(&inverse, q) } else { panic!("error: gcd!=one") } } pub fn extended_gcd(a: BigInt, b: BigInt) -> (BigInt, BigInt, BigInt) { let (mut s, mut old_s) = (BigInt::zero(), BigInt::one()); let (mut t, mut old_t) = (BigInt::one(), BigInt::zero()); let (mut r, mut old_r) = (b, a); while r != BigInt::zero() { let quotient = &old_r / &r; old_r -= "ient * &r; std::mem::swap(&mut old_r, &mut r); old_s -= "ient * &s; std::mem::swap(&mut old_s, &mut s); old_t -= quotient * &t; std::mem::swap(&mut old_t, &mut t); } let _quotients = (t, s); // == (a, b) / gcd (old_r, old_s, old_t) } */ pub fn concatenate_arrays(x: &[T], y: &[T]) -> Vec { x.iter().chain(y).cloned().collect() } pub fn modsqrt(a: &BigInt, q: &BigInt) -> BigInt { // Tonelli-Shanks Algorithm (https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm) // // This implementation is following the Go lang core implementation https://golang.org/src/math/big/int.go?s=23173:23210#L859 // Also described in https://www.maa.org/sites/default/files/pdf/upload_library/22/Polya/07468342.di020786.02p0470a.pdf // -> section 6 let zero: BigInt = Zero::zero(); let one: BigInt = One::one(); if legendre_symbol(&a, q) != 1 { // not a mod p square return zero; } else if a == &zero { return zero; } else if q == &2.to_bigint().unwrap() { return zero; } else if q % 4.to_bigint().unwrap() == 3.to_bigint().unwrap() { let r = a.modpow(&((q + one) / 4), &q); return r; } let mut s = q - &one; let mut e: BigInt = Zero::zero(); while &s % 2 == zero { s = s >> 1; e = e + &one; } let mut n: BigInt = 2.to_bigint().unwrap(); while legendre_symbol(&n, q) != -1 { n = &n + &one; } let mut y = a.modpow(&((&s + &one) >> 1), q); let mut b = a.modpow(&s, q); let mut g = n.modpow(&s, q); let mut r = e; loop { let mut t = b.clone(); let mut m: BigInt = Zero::zero(); while &t != &one { t = modulus(&(&t * &t), q); m = m + &one; } if m == zero { return y.clone(); } t = g.modpow(&(2.to_bigint().unwrap().modpow(&(&r - &m - 1), q)), q); g = g.modpow(&(2.to_bigint().unwrap().modpow(&(r - &m), q)), q); y = modulus(&(y * t), q); b = modulus(&(b * &g), q); r = m.clone(); } } #[allow(dead_code)] pub fn modsqrt_v2(a: &BigInt, q: &BigInt) -> BigInt { // Tonelli-Shanks Algorithm (https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm) // // This implementation is following this Python implementation by Dusk https://github.com/dusk-network/dusk-zerocaf/blob/master/tools/tonelli.py let zero: BigInt = Zero::zero(); let one: BigInt = One::one(); if legendre_symbol(&a, q) != 1 { // not a mod p square return zero; } else if a == &zero { return zero; } else if q == &2.to_bigint().unwrap() { return zero; } else if q % 4.to_bigint().unwrap() == 3.to_bigint().unwrap() { let r = a.modpow(&((q + one) / 4), &q); return r; } let mut p = q - &one; let mut s: BigInt = Zero::zero(); while &p % 2.to_bigint().unwrap() == zero { s = s + &one; p = p >> 1; } let mut z: BigInt = One::one(); while legendre_symbol(&z, q) != -1 { z = &z + &one; } let mut c = z.modpow(&p, q); let mut x = a.modpow(&((&p + &one) >> 1), q); let mut t = a.modpow(&p, q); let mut m = s; while &t != &one { let mut i: BigInt = One::one(); let mut e: BigInt = 2.to_bigint().unwrap(); while i < m { if t.modpow(&e, q) == one { break; } e = e * 2.to_bigint().unwrap(); i = i + &one; } let b = c.modpow(&(2.to_bigint().unwrap().modpow(&(&m - &i - 1), q)), q); x = modulus(&(x * &b), q); t = modulus(&(t * &b * &b), q); c = modulus(&(&b * &b), q); m = i.clone(); } return x; } pub fn legendre_symbol(a: &BigInt, q: &BigInt) -> i32 { // returns 1 if has a square root modulo q let one: BigInt = One::one(); let ls: BigInt = a.modpow(&((q - &one) >> 1), &q); if &(ls) == &(q - one) { return -1; } 1 } #[cfg(test)] mod tests { use super::*; #[test] fn test_mod_inverse() { let a = BigInt::parse_bytes(b"123456789123456789123456789123456789123456789", 10).unwrap(); let b = BigInt::parse_bytes(b"12345678", 10).unwrap(); assert_eq!(modinv(&a, &b), BigInt::parse_bytes(b"641883", 10).unwrap()); } #[test] fn test_sqrtmod() { let a = BigInt::parse_bytes( b"6536923810004159332831702809452452174451353762940761092345538667656658715568", 10, ) .unwrap(); let q = BigInt::parse_bytes( b"7237005577332262213973186563042994240857116359379907606001950938285454250989", 10, ) .unwrap(); assert_eq!( (modsqrt(&a, &q)).to_string(), "5464794816676661649783249706827271879994893912039750480019443499440603127256" ); assert_eq!( (modsqrt_v2(&a, &q)).to_string(), "5464794816676661649783249706827271879994893912039750480019443499440603127256" ); } }