You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

265 lines
7.4 KiB

  1. // BabyJubJub elliptic curve implementation in Rust.
  2. // For LICENSE check https://github.com/arnaucube/babyjubjub-rs
  3. extern crate num;
  4. extern crate num_bigint;
  5. extern crate num_traits;
  6. use num_bigint::{BigInt, ToBigInt};
  7. use num_traits::{One, Zero};
  8. pub fn modulus(a: &BigInt, m: &BigInt) -> BigInt {
  9. ((a % m) + m) % m
  10. }
  11. pub fn modinv(a: &BigInt, q: &BigInt) -> Result<BigInt, String> {
  12. let big_zero: BigInt = Zero::zero();
  13. if a == &big_zero {
  14. return Err("no mod inv of Zero".to_string());
  15. }
  16. let mut mn = (q.clone(), a.clone());
  17. let mut xy: (BigInt, BigInt) = (Zero::zero(), One::one());
  18. while mn.1 != big_zero {
  19. xy = (xy.1.clone(), xy.0 - (mn.0.clone() / mn.1.clone()) * xy.1);
  20. mn = (mn.1.clone(), modulus(&mn.0, &mn.1));
  21. }
  22. while xy.0 < Zero::zero() {
  23. xy.0 = modulus(&xy.0, q);
  24. }
  25. Ok(xy.0)
  26. }
  27. /*
  28. pub fn modinv_v2(a0: &BigInt, m0: &BigInt) -> BigInt {
  29. if m0 == &One::one() {
  30. return One::one();
  31. }
  32. let (mut a, mut m, mut x0, mut inv): (BigInt, BigInt, BigInt, BigInt) =
  33. (a0.clone(), m0.clone(), Zero::zero(), One::one());
  34. while a > One::one() {
  35. inv = inv - (&a / m.clone()) * x0.clone();
  36. a = a % m.clone();
  37. std::mem::swap(&mut a, &mut m);
  38. std::mem::swap(&mut x0, &mut inv);
  39. }
  40. if inv < Zero::zero() {
  41. inv += m0.clone()
  42. }
  43. inv
  44. }
  45. pub fn modinv_v3(a: &BigInt, q: &BigInt) -> BigInt {
  46. let mut aa: BigInt = a.clone();
  47. let mut qq: BigInt = q.clone();
  48. if qq < Zero::zero() {
  49. qq = -qq;
  50. }
  51. if aa < Zero::zero() {
  52. aa = -aa;
  53. }
  54. let d = num::Integer::gcd(&aa, &qq);
  55. if d != One::one() {
  56. println!("ERR no mod_inv");
  57. }
  58. let res: BigInt;
  59. if d < Zero::zero() {
  60. res = d + qq;
  61. } else {
  62. res = d;
  63. }
  64. res
  65. }
  66. pub fn modinv_v4(x: &BigInt, q: &BigInt) -> BigInt {
  67. let (gcd, inverse, _) = extended_gcd(x.clone(), q.clone());
  68. let one: BigInt = One::one();
  69. if gcd == one {
  70. modulus(&inverse, q)
  71. } else {
  72. panic!("error: gcd!=one")
  73. }
  74. }
  75. pub fn extended_gcd(a: BigInt, b: BigInt) -> (BigInt, BigInt, BigInt) {
  76. let (mut s, mut old_s) = (BigInt::zero(), BigInt::one());
  77. let (mut t, mut old_t) = (BigInt::one(), BigInt::zero());
  78. let (mut r, mut old_r) = (b, a);
  79. while r != BigInt::zero() {
  80. let quotient = &old_r / &r;
  81. old_r -= &quotient * &r;
  82. std::mem::swap(&mut old_r, &mut r);
  83. old_s -= &quotient * &s;
  84. std::mem::swap(&mut old_s, &mut s);
  85. old_t -= quotient * &t;
  86. std::mem::swap(&mut old_t, &mut t);
  87. }
  88. let _quotients = (t, s); // == (a, b) / gcd
  89. (old_r, old_s, old_t)
  90. }
  91. */
  /// Returns a new vector with clones of every element of `x`, followed by clones of every
  /// element of `y`, in their original order.
  pub fn concatenate_arrays<T: Clone>(x: &[T], y: &[T]) -> Vec<T> {
      [x, y].concat()
  }
  95. #[allow(clippy::many_single_char_names)]
  96. pub fn modsqrt(a: &BigInt, q: &BigInt) -> Result<BigInt, String> {
  97. // Tonelli-Shanks Algorithm (https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm)
  98. //
  99. // This implementation is following the Go lang core implementation https://golang.org/src/math/big/int.go?s=23173:23210#L859
  100. // Also described in https://www.maa.org/sites/default/files/pdf/upload_library/22/Polya/07468342.di020786.02p0470a.pdf
  101. // -> section 6
  102. let zero: BigInt = Zero::zero();
  103. let one: BigInt = One::one();
  104. if legendre_symbol(&a, q) != 1 || a == &zero || q == &2.to_bigint().unwrap() {
  105. return Err("not a mod p square".to_string());
  106. } else if q % 4.to_bigint().unwrap() == 3.to_bigint().unwrap() {
  107. let r = a.modpow(&((q + one) / 4), &q);
  108. return Ok(r);
  109. }
  110. let mut s = q - &one;
  111. let mut e: BigInt = Zero::zero();
  112. while &s % 2 == zero {
  113. s >>= 1;
  114. e += &one;
  115. }
  116. let mut n: BigInt = 2.to_bigint().unwrap();
  117. while legendre_symbol(&n, q) != -1 {
  118. n = &n + &one;
  119. }
  120. let mut y = a.modpow(&((&s + &one) >> 1), q);
  121. let mut b = a.modpow(&s, q);
  122. let mut g = n.modpow(&s, q);
  123. let mut r = e;
  124. loop {
  125. let mut t = b.clone();
  126. let mut m: BigInt = Zero::zero();
  127. while t != one {
  128. t = modulus(&(&t * &t), q);
  129. m += &one;
  130. }
  131. if m == zero {
  132. return Ok(y);
  133. }
  134. t = g.modpow(&(2.to_bigint().unwrap().modpow(&(&r - &m - 1), q)), q);
  135. g = g.modpow(&(2.to_bigint().unwrap().modpow(&(r - &m), q)), q);
  136. y = modulus(&(y * t), q);
  137. b = modulus(&(b * &g), q);
  138. r = m.clone();
  139. }
  140. }
  141. #[allow(dead_code)]
  142. #[allow(clippy::many_single_char_names)]
  143. pub fn modsqrt_v2(a: &BigInt, q: &BigInt) -> Result<BigInt, String> {
  144. // Tonelli-Shanks Algorithm (https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm)
  145. //
  146. // This implementation is following this Python implementation by Dusk https://github.com/dusk-network/dusk-zerocaf/blob/master/tools/tonelli.py
  147. let zero: BigInt = Zero::zero();
  148. let one: BigInt = One::one();
  149. if legendre_symbol(&a, q) != 1 || a == &zero || q == &2.to_bigint().unwrap() {
  150. return Err("not a mod p square".to_string());
  151. } else if q % 4.to_bigint().unwrap() == 3.to_bigint().unwrap() {
  152. let r = a.modpow(&((q + one) / 4), &q);
  153. return Ok(r);
  154. }
  155. let mut p = q - &one;
  156. let mut s: BigInt = Zero::zero();
  157. while &p % 2.to_bigint().unwrap() == zero {
  158. s += &one;
  159. p >>= 1;
  160. }
  161. let mut z: BigInt = One::one();
  162. while legendre_symbol(&z, q) != -1 {
  163. z = &z + &one;
  164. }
  165. let mut c = z.modpow(&p, q);
  166. let mut x = a.modpow(&((&p + &one) >> 1), q);
  167. let mut t = a.modpow(&p, q);
  168. let mut m = s;
  169. while t != one {
  170. let mut i: BigInt = One::one();
  171. let mut e: BigInt = 2.to_bigint().unwrap();
  172. while i < m {
  173. if t.modpow(&e, q) == one {
  174. break;
  175. }
  176. e *= 2.to_bigint().unwrap();
  177. i += &one;
  178. }
  179. let b = c.modpow(&(2.to_bigint().unwrap().modpow(&(&m - &i - 1), q)), q);
  180. x = modulus(&(x * &b), q);
  181. t = modulus(&(t * &b * &b), q);
  182. c = modulus(&(&b * &b), q);
  183. m = i.clone();
  184. }
  185. Ok(x)
  186. }
  187. pub fn legendre_symbol(a: &BigInt, q: &BigInt) -> i32 {
  188. // returns 1 if has a square root modulo q
  189. let one: BigInt = One::one();
  190. let ls: BigInt = a.modpow(&((q - &one) >> 1), &q);
  191. if ls == q - one {
  192. return -1;
  193. }
  194. 1
  195. }
  #[cfg(test)]
  mod tests {
      use super::*;

      /// Parses a base-10 string into a `BigInt`, panicking on malformed input.
      fn dec(digits: &str) -> BigInt {
          BigInt::parse_bytes(digits.as_bytes(), 10).unwrap()
      }

      #[test]
      fn test_mod_inverse() {
          // NOTE(review): both operands are divisible by 9 (digit sums 225 and 36), so no true
          // inverse exists. This pins the value `modinv` currently returns for non-coprime
          // inputs rather than a real modular inverse.
          let a = dec("123456789123456789123456789123456789123456789");
          let m = dec("12345678");
          assert_eq!(modinv(&a, &m).unwrap(), dec("641883"));
      }

      #[test]
      fn test_sqrtmod() {
          let a = dec("6536923810004159332831702809452452174451353762940761092345538667656658715568");
          let q = dec("7237005577332262213973186563042994240857116359379907606001950938285454250989");
          let expected = "5464794816676661649783249706827271879994893912039750480019443499440603127256";
          assert_eq!(modsqrt(&a, &q).unwrap().to_string(), expected);
          assert_eq!(modsqrt_v2(&a, &q).unwrap().to_string(), expected);
      }
  }