263 lines
7.2 KiB

extern crate num;
extern crate num_bigint;
extern crate num_traits;
use num_bigint::{BigInt, ToBigInt};
use num_traits::{One, Zero};
/// Reduces `a` modulo `m`.
///
/// `%` on `BigInt` truncates (the remainder takes the sign of `a`), so the remainder is
/// shifted by `m` and reduced a second time; for a positive `m` this yields a value in
/// `[0, m)` even when `a` is negative.
pub fn modulus(a: &BigInt, m: &BigInt) -> BigInt {
    let truncated = a % m;
    let lifted = truncated + m;
    lifted % m
}
/// Returns the multiplicative inverse of `a` modulo `q`: the value `x` in `[0, q)` with
/// `a * x ≡ 1 (mod q)`.
///
/// Implemented with the iterative extended Euclidean algorithm. `a` is first reduced
/// into `[0, q)`, so negative inputs and inputs larger than `q` are both handled.
///
/// `q` is expected to be positive. When `gcd(a, q) != 1` no inverse exists and the
/// returned value is not an inverse (e.g. `a ≡ 0 (mod q)` yields `0`); callers must
/// ensure `a` and `q` are coprime.
pub fn modinv(a: &BigInt, q: &BigInt) -> BigInt {
    // Invariant: old_r ≡ old_t * a and r ≡ t * a (mod q).
    // Because `a` is reduced first, every remainder is non-negative, so the truncating `/`
    // used for the quotient is consistent with the remainder computed from it. (Mixing
    // truncating `/` with a floored modulus on negative `a` breaks this invariant.)
    let (mut old_r, mut r) = (q.clone(), modulus(a, q));
    let (mut old_t, mut t): (BigInt, BigInt) = (Zero::zero(), One::one());
    while !r.is_zero() {
        let quotient = &old_r / &r;
        let next_r = &old_r - &quotient * &r;
        let next_t = &old_t - &quotient * &t;
        old_r = std::mem::replace(&mut r, next_r);
        old_t = std::mem::replace(&mut t, next_t);
    }
    // old_r is now gcd(a, q) and old_t the matching Bezout coefficient of `a`;
    // bring it into [0, q).
    modulus(&old_t, q)
}
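// Alternative modular-inverse / extended-GCD implementations, kept for reference only
// (commented out, never compiled).
// NOTE(review): `modinv_v3` below is incorrect as written: it returns gcd(|a|, |q|),
// never a modular inverse. Do not re-enable it without rewriting it.
/*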
pub fn modinv_v2(a0: &BigInt, m0: &BigInt) -> BigInt {
if m0 == &One::one() {
return One::one();
}
let (mut a, mut m, mut x0, mut inv): (BigInt, BigInt, BigInt, BigInt) =
(a0.clone(), m0.clone(), Zero::zero(), One::one());
while a > One::one() {
inv = inv - (&a / m.clone()) * x0.clone();
a = a % m.clone();
std::mem::swap(&mut a, &mut m);
std::mem::swap(&mut x0, &mut inv);
}
if inv < Zero::zero() {
inv += m0.clone()
}
inv
}
pub fn modinv_v3(a: &BigInt, q: &BigInt) -> BigInt {
let mut aa: BigInt = a.clone();
let mut qq: BigInt = q.clone();
if qq < Zero::zero() {
qq = -qq;
}
if aa < Zero::zero() {
aa = -aa;
}
let d = num::Integer::gcd(&aa, &qq);
if d != One::one() {
println!("ERR no mod_inv");
}
let res: BigInt;
if d < Zero::zero() {
res = d + qq;
} else {
res = d;
}
res
}
pub fn modinv_v4(x: &BigInt, q: &BigInt) -> BigInt {
let (gcd, inverse, _) = extended_gcd(x.clone(), q.clone());
let one: BigInt = One::one();
if gcd == one {
modulus(&inverse, q)
} else {
panic!("error: gcd!=one")
}
}
pub fn extended_gcd(a: BigInt, b: BigInt) -> (BigInt, BigInt, BigInt) {
let (mut s, mut old_s) = (BigInt::zero(), BigInt::one());
let (mut t, mut old_t) = (BigInt::one(), BigInt::zero());
let (mut r, mut old_r) = (b, a);
while r != BigInt::zero() {
let quotient = &old_r / &r;
old_r -= &quotient * &r;
std::mem::swap(&mut old_r, &mut r);
old_s -= &quotient * &s;
std::mem::swap(&mut old_s, &mut s);
old_t -= quotient * &t;
std::mem::swap(&mut old_t, &mut t);
}
let _quotients = (t, s); // == (a, b) / gcd
(old_r, old_s, old_t)
}
*/
/// Builds a new vector containing clones of all elements of `x`, followed by clones of
/// all elements of `y`. Neither input is modified.
pub fn concatenate_arrays<T: Clone>(x: &[T], y: &[T]) -> Vec<T> {
    [x, y].concat()
}
pub fn modsqrt(a: &BigInt, q: &BigInt) -> BigInt {
// Tonelli-Shanks Algorithm (https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm)
//
// This implementation is following the Go lang core implementation https://golang.org/src/math/big/int.go?s=23173:23210#L859
// Also described in https://www.maa.org/sites/default/files/pdf/upload_library/22/Polya/07468342.di020786.02p0470a.pdf
// -> section 6
let zero: BigInt = Zero::zero();
let one: BigInt = One::one();
if legendre_symbol(&a, q) != 1 {
// not a mod p square
return zero;
} else if a == &zero {
return zero;
} else if q == &2.to_bigint().unwrap() {
return zero;
} else if q % 4.to_bigint().unwrap() == 3.to_bigint().unwrap() {
let r = a.modpow(&((q + one) / 4), &q);
return r;
}
let mut s = q - &one;
let mut e: BigInt = Zero::zero();
while &s % 2 == zero {
s = s >> 1;
e = e + &one;
}
let mut n: BigInt = 2.to_bigint().unwrap();
while legendre_symbol(&n, q) != -1 {
n = &n + &one;
}
let mut y = a.modpow(&((&s + &one) >> 1), q);
let mut b = a.modpow(&s, q);
let mut g = n.modpow(&s, q);
let mut r = e;
loop {
let mut t = b.clone();
let mut m: BigInt = Zero::zero();
while &t != &one {
t = modulus(&(&t * &t), q);
m = m + &one;
}
if m == zero {
return y.clone();
}
t = g.modpow(&(2.to_bigint().unwrap().modpow(&(&r - &m - 1), q)), q);
g = g.modpow(&(2.to_bigint().unwrap().modpow(&(r - &m), q)), q);
y = modulus(&(y * t), q);
b = modulus(&(b * &g), q);
r = m.clone();
}
}
#[allow(dead_code)]
pub fn modsqrt_v2(a: &BigInt, q: &BigInt) -> BigInt {
// Tonelli-Shanks Algorithm (https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm)
//
// This implementation is following this Python implementation by Dusk https://github.com/dusk-network/dusk-zerocaf/blob/master/tools/tonelli.py
let zero: BigInt = Zero::zero();
let one: BigInt = One::one();
if legendre_symbol(&a, q) != 1 {
// not a mod p square
return zero;
} else if a == &zero {
return zero;
} else if q == &2.to_bigint().unwrap() {
return zero;
} else if q % 4.to_bigint().unwrap() == 3.to_bigint().unwrap() {
let r = a.modpow(&((q + one) / 4), &q);
return r;
}
let mut p = q - &one;
let mut s: BigInt = Zero::zero();
while &p % 2.to_bigint().unwrap() == zero {
s = s + &one;
p = p >> 1;
}
let mut z: BigInt = One::one();
while legendre_symbol(&z, q) != -1 {
z = &z + &one;
}
let mut c = z.modpow(&p, q);
let mut x = a.modpow(&((&p + &one) >> 1), q);
let mut t = a.modpow(&p, q);
let mut m = s;
while &t != &one {
let mut i: BigInt = One::one();
let mut e: BigInt = 2.to_bigint().unwrap();
while i < m {
if t.modpow(&e, q) == one {
break;
}
e = e * 2.to_bigint().unwrap();
i = i + &one;
}
let b = c.modpow(&(2.to_bigint().unwrap().modpow(&(&m - &i - 1), q)), q);
x = modulus(&(x * &b), q);
t = modulus(&(t * &b * &b), q);
c = modulus(&(&b * &b), q);
m = i.clone();
}
return x;
}
pub fn legendre_symbol(a: &BigInt, q: &BigInt) -> i32 {
// returns 1 if has a square root modulo q
let one: BigInt = One::one();
let ls: BigInt = a.modpow(&((q - &one) >> 1), &q);
if &(ls) == &(q - one) {
return -1;
}
1
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a base-10 literal into a `BigInt`; panics on malformed input.
    fn big(digits: &str) -> BigInt {
        BigInt::parse_bytes(digits.as_bytes(), 10).unwrap()
    }

    #[test]
    fn test_mod_inverse() {
        let value = big("123456789123456789123456789123456789123456789");
        let modulo = big("12345678");
        let expected = big("641883");
        assert_eq!(modinv(&value, &modulo), expected);
    }

    #[test]
    fn test_sqrtmod() {
        let square = big("6536923810004159332831702809452452174451353762940761092345538667656658715568");
        let prime = big("7237005577332262213973186563042994240857116359379907606001950938285454250989");
        let expected_root =
            "5464794816676661649783249706827271879994893912039750480019443499440603127256";
        assert_eq!(modsqrt(&square, &prime).to_string(), expected_root);
        assert_eq!(modsqrt_v2(&square, &prime).to_string(), expected_root);
    }
}