You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

263 lines
7.2 KiB

  1. extern crate num;
  2. extern crate num_bigint;
  3. extern crate num_traits;
  4. use num_bigint::{BigInt, ToBigInt};
  5. use num_traits::{One, Zero};
  6. pub fn modulus(a: &BigInt, m: &BigInt) -> BigInt {
  7. ((a % m) + m) % m
  8. }
  9. pub fn modinv(a: &BigInt, q: &BigInt) -> BigInt {
  10. let mut mn = (q.clone(), a.clone());
  11. let mut xy: (BigInt, BigInt) = (Zero::zero(), One::one());
  12. let big_zero: BigInt = Zero::zero();
  13. while mn.1 != big_zero {
  14. xy = (xy.1.clone(), xy.0 - (mn.0.clone() / mn.1.clone()) * xy.1);
  15. mn = (mn.1.clone(), modulus(&mn.0, &mn.1));
  16. }
  17. while xy.0 < Zero::zero() {
  18. xy.0 = modulus(&xy.0, q);
  19. }
  20. xy.0
  21. }
  22. /*
  23. pub fn modinv_v2(a0: &BigInt, m0: &BigInt) -> BigInt {
  24. if m0 == &One::one() {
  25. return One::one();
  26. }
  27. let (mut a, mut m, mut x0, mut inv): (BigInt, BigInt, BigInt, BigInt) =
  28. (a0.clone(), m0.clone(), Zero::zero(), One::one());
  29. while a > One::one() {
  30. inv = inv - (&a / m.clone()) * x0.clone();
  31. a = a % m.clone();
  32. std::mem::swap(&mut a, &mut m);
  33. std::mem::swap(&mut x0, &mut inv);
  34. }
  35. if inv < Zero::zero() {
  36. inv += m0.clone()
  37. }
  38. inv
  39. }
  40. pub fn modinv_v3(a: &BigInt, q: &BigInt) -> BigInt {
  41. let mut aa: BigInt = a.clone();
  42. let mut qq: BigInt = q.clone();
  43. if qq < Zero::zero() {
  44. qq = -qq;
  45. }
  46. if aa < Zero::zero() {
  47. aa = -aa;
  48. }
  49. let d = num::Integer::gcd(&aa, &qq);
  50. if d != One::one() {
  51. println!("ERR no mod_inv");
  52. }
  53. let res: BigInt;
  54. if d < Zero::zero() {
  55. res = d + qq;
  56. } else {
  57. res = d;
  58. }
  59. res
  60. }
  61. pub fn modinv_v4(x: &BigInt, q: &BigInt) -> BigInt {
  62. let (gcd, inverse, _) = extended_gcd(x.clone(), q.clone());
  63. let one: BigInt = One::one();
  64. if gcd == one {
  65. modulus(&inverse, q)
  66. } else {
  67. panic!("error: gcd!=one")
  68. }
  69. }
  70. pub fn extended_gcd(a: BigInt, b: BigInt) -> (BigInt, BigInt, BigInt) {
  71. let (mut s, mut old_s) = (BigInt::zero(), BigInt::one());
  72. let (mut t, mut old_t) = (BigInt::one(), BigInt::zero());
  73. let (mut r, mut old_r) = (b, a);
  74. while r != BigInt::zero() {
  75. let quotient = &old_r / &r;
  76. old_r -= &quotient * &r;
  77. std::mem::swap(&mut old_r, &mut r);
  78. old_s -= &quotient * &s;
  79. std::mem::swap(&mut old_s, &mut s);
  80. old_t -= quotient * &t;
  81. std::mem::swap(&mut old_t, &mut t);
  82. }
  83. let _quotients = (t, s); // == (a, b) / gcd
  84. (old_r, old_s, old_t)
  85. }
  86. */
  87. pub fn concatenate_arrays<T: Clone>(x: &[T], y: &[T]) -> Vec<T> {
  88. x.iter().chain(y).cloned().collect()
  89. }
  90. pub fn modsqrt(a: &BigInt, q: &BigInt) -> BigInt {
  91. // Tonelli-Shanks Algorithm (https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm)
  92. //
  93. // This implementation is following the Go lang core implementation https://golang.org/src/math/big/int.go?s=23173:23210#L859
  94. // Also described in https://www.maa.org/sites/default/files/pdf/upload_library/22/Polya/07468342.di020786.02p0470a.pdf
  95. // -> section 6
  96. let zero: BigInt = Zero::zero();
  97. let one: BigInt = One::one();
  98. if legendre_symbol(&a, q) != 1 {
  99. // not a mod p square
  100. return zero;
  101. } else if a == &zero {
  102. return zero;
  103. } else if q == &2.to_bigint().unwrap() {
  104. return zero;
  105. } else if q % 4.to_bigint().unwrap() == 3.to_bigint().unwrap() {
  106. let r = a.modpow(&((q + one) / 4), &q);
  107. return r;
  108. }
  109. let mut s = q - &one;
  110. let mut e: BigInt = Zero::zero();
  111. while &s % 2 == zero {
  112. s = s >> 1;
  113. e = e + &one;
  114. }
  115. let mut n: BigInt = 2.to_bigint().unwrap();
  116. while legendre_symbol(&n, q) != -1 {
  117. n = &n + &one;
  118. }
  119. let mut y = a.modpow(&((&s + &one) >> 1), q);
  120. let mut b = a.modpow(&s, q);
  121. let mut g = n.modpow(&s, q);
  122. let mut r = e;
  123. loop {
  124. let mut t = b.clone();
  125. let mut m: BigInt = Zero::zero();
  126. while &t != &one {
  127. t = modulus(&(&t * &t), q);
  128. m = m + &one;
  129. }
  130. if m == zero {
  131. return y.clone();
  132. }
  133. t = g.modpow(&(2.to_bigint().unwrap().modpow(&(&r - &m - 1), q)), q);
  134. g = g.modpow(&(2.to_bigint().unwrap().modpow(&(r - &m), q)), q);
  135. y = modulus(&(y * t), q);
  136. b = modulus(&(b * &g), q);
  137. r = m.clone();
  138. }
  139. }
  140. #[allow(dead_code)]
  141. pub fn modsqrt_v2(a: &BigInt, q: &BigInt) -> BigInt {
  142. // Tonelli-Shanks Algorithm (https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm)
  143. //
  144. // This implementation is following this Python implementation by Dusk https://github.com/dusk-network/dusk-zerocaf/blob/master/tools/tonelli.py
  145. let zero: BigInt = Zero::zero();
  146. let one: BigInt = One::one();
  147. if legendre_symbol(&a, q) != 1 {
  148. // not a mod p square
  149. return zero;
  150. } else if a == &zero {
  151. return zero;
  152. } else if q == &2.to_bigint().unwrap() {
  153. return zero;
  154. } else if q % 4.to_bigint().unwrap() == 3.to_bigint().unwrap() {
  155. let r = a.modpow(&((q + one) / 4), &q);
  156. return r;
  157. }
  158. let mut p = q - &one;
  159. let mut s: BigInt = Zero::zero();
  160. while &p % 2.to_bigint().unwrap() == zero {
  161. s = s + &one;
  162. p = p >> 1;
  163. }
  164. let mut z: BigInt = One::one();
  165. while legendre_symbol(&z, q) != -1 {
  166. z = &z + &one;
  167. }
  168. let mut c = z.modpow(&p, q);
  169. let mut x = a.modpow(&((&p + &one) >> 1), q);
  170. let mut t = a.modpow(&p, q);
  171. let mut m = s;
  172. while &t != &one {
  173. let mut i: BigInt = One::one();
  174. let mut e: BigInt = 2.to_bigint().unwrap();
  175. while i < m {
  176. if t.modpow(&e, q) == one {
  177. break;
  178. }
  179. e = e * 2.to_bigint().unwrap();
  180. i = i + &one;
  181. }
  182. let b = c.modpow(&(2.to_bigint().unwrap().modpow(&(&m - &i - 1), q)), q);
  183. x = modulus(&(x * &b), q);
  184. t = modulus(&(t * &b * &b), q);
  185. c = modulus(&(&b * &b), q);
  186. m = i.clone();
  187. }
  188. return x;
  189. }
  190. pub fn legendre_symbol(a: &BigInt, q: &BigInt) -> i32 {
  191. // returns 1 if has a square root modulo q
  192. let one: BigInt = One::one();
  193. let ls: BigInt = a.modpow(&((q - &one) >> 1), &q);
  194. if &(ls) == &(q - one) {
  195. return -1;
  196. }
  197. 1
  198. }
  199. #[cfg(test)]
  200. mod tests {
  201. use super::*;
  202. #[test]
  203. fn test_mod_inverse() {
  204. let a = BigInt::parse_bytes(b"123456789123456789123456789123456789123456789", 10).unwrap();
  205. let b = BigInt::parse_bytes(b"12345678", 10).unwrap();
  206. assert_eq!(modinv(&a, &b), BigInt::parse_bytes(b"641883", 10).unwrap());
  207. }
  208. #[test]
  209. fn test_sqrtmod() {
  210. let a = BigInt::parse_bytes(
  211. b"6536923810004159332831702809452452174451353762940761092345538667656658715568",
  212. 10,
  213. )
  214. .unwrap();
  215. let q = BigInt::parse_bytes(
  216. b"7237005577332262213973186563042994240857116359379907606001950938285454250989",
  217. 10,
  218. )
  219. .unwrap();
  220. assert_eq!(
  221. (modsqrt(&a, &q)).to_string(),
  222. "5464794816676661649783249706827271879994893912039750480019443499440603127256"
  223. );
  224. assert_eq!(
  225. (modsqrt_v2(&a, &q)).to_string(),
  226. "5464794816676661649783249706827271879994893912039750480019443499440603127256"
  227. );
  228. }
  229. }