You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

261 lines
7.3 KiB

  1. // BabyJubJub elliptic curve implementation in Rust.
  2. // For LICENSE check https://github.com/arnaucube/babyjubjub-rs
  3. use num_bigint::{BigInt, ToBigInt};
  4. use num_traits::{One, Zero};
  5. pub fn modulus(a: &BigInt, m: &BigInt) -> BigInt {
  6. ((a % m) + m) % m
  7. }
  8. pub fn modinv(a: &BigInt, q: &BigInt) -> Result<BigInt, String> {
  9. let big_zero: BigInt = Zero::zero();
  10. if a == &big_zero {
  11. return Err("no mod inv of Zero".to_string());
  12. }
  13. let mut mn = (q.clone(), a.clone());
  14. let mut xy: (BigInt, BigInt) = (Zero::zero(), One::one());
  15. while mn.1 != big_zero {
  16. xy = (xy.1.clone(), xy.0 - (mn.0.clone() / mn.1.clone()) * xy.1);
  17. mn = (mn.1.clone(), modulus(&mn.0, &mn.1));
  18. }
  19. while xy.0 < Zero::zero() {
  20. xy.0 = modulus(&xy.0, q);
  21. }
  22. Ok(xy.0)
  23. }
  24. /*
  25. pub fn modinv_v2(a0: &BigInt, m0: &BigInt) -> BigInt {
  26. if m0 == &One::one() {
  27. return One::one();
  28. }
  29. let (mut a, mut m, mut x0, mut inv): (BigInt, BigInt, BigInt, BigInt) =
  30. (a0.clone(), m0.clone(), Zero::zero(), One::one());
  31. while a > One::one() {
  32. inv = inv - (&a / m.clone()) * x0.clone();
  33. a = a % m.clone();
  34. std::mem::swap(&mut a, &mut m);
  35. std::mem::swap(&mut x0, &mut inv);
  36. }
  37. if inv < Zero::zero() {
  38. inv += m0.clone()
  39. }
  40. inv
  41. }
  42. pub fn modinv_v3(a: &BigInt, q: &BigInt) -> BigInt {
  43. let mut aa: BigInt = a.clone();
  44. let mut qq: BigInt = q.clone();
  45. if qq < Zero::zero() {
  46. qq = -qq;
  47. }
  48. if aa < Zero::zero() {
  49. aa = -aa;
  50. }
  51. let d = num::Integer::gcd(&aa, &qq);
  52. if d != One::one() {
  53. println!("ERR no mod_inv");
  54. }
  55. let res: BigInt;
  56. if d < Zero::zero() {
  57. res = d + qq;
  58. } else {
  59. res = d;
  60. }
  61. res
  62. }
  63. pub fn modinv_v4(x: &BigInt, q: &BigInt) -> BigInt {
  64. let (gcd, inverse, _) = extended_gcd(x.clone(), q.clone());
  65. let one: BigInt = One::one();
  66. if gcd == one {
  67. modulus(&inverse, q)
  68. } else {
  69. panic!("error: gcd!=one")
  70. }
  71. }
  72. pub fn extended_gcd(a: BigInt, b: BigInt) -> (BigInt, BigInt, BigInt) {
  73. let (mut s, mut old_s) = (BigInt::zero(), BigInt::one());
  74. let (mut t, mut old_t) = (BigInt::one(), BigInt::zero());
  75. let (mut r, mut old_r) = (b, a);
  76. while r != BigInt::zero() {
  77. let quotient = &old_r / &r;
  78. old_r -= &quotient * &r;
  79. std::mem::swap(&mut old_r, &mut r);
  80. old_s -= &quotient * &s;
  81. std::mem::swap(&mut old_s, &mut s);
  82. old_t -= quotient * &t;
  83. std::mem::swap(&mut old_t, &mut t);
  84. }
  85. let _quotients = (t, s); // == (a, b) / gcd
  86. (old_r, old_s, old_t)
  87. }
  88. */
  89. pub fn concatenate_arrays<T: Clone>(x: &[T], y: &[T]) -> Vec<T> {
  90. x.iter().chain(y).cloned().collect()
  91. }
  92. #[allow(clippy::many_single_char_names)]
  93. pub fn modsqrt(a: &BigInt, q: &BigInt) -> Result<BigInt, String> {
  94. // Tonelli-Shanks Algorithm (https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm)
  95. //
  96. // This implementation is following the Go lang core implementation https://golang.org/src/math/big/int.go?s=23173:23210#L859
  97. // Also described in https://www.maa.org/sites/default/files/pdf/upload_library/22/Polya/07468342.di020786.02p0470a.pdf
  98. // -> section 6
  99. let zero: BigInt = Zero::zero();
  100. let one: BigInt = One::one();
  101. if legendre_symbol(a, q) != 1 || a == &zero || q == &2.to_bigint().unwrap() {
  102. return Err("not a mod p square".to_string());
  103. } else if q % 4.to_bigint().unwrap() == 3.to_bigint().unwrap() {
  104. let r = a.modpow(&((q + one) / 4), q);
  105. return Ok(r);
  106. }
  107. let mut s = q - &one;
  108. let mut e: BigInt = Zero::zero();
  109. while &s % 2 == zero {
  110. s >>= 1;
  111. e += &one;
  112. }
  113. let mut n: BigInt = 2.to_bigint().unwrap();
  114. while legendre_symbol(&n, q) != -1 {
  115. n = &n + &one;
  116. }
  117. let mut y = a.modpow(&((&s + &one) >> 1), q);
  118. let mut b = a.modpow(&s, q);
  119. let mut g = n.modpow(&s, q);
  120. let mut r = e;
  121. loop {
  122. let mut t = b.clone();
  123. let mut m: BigInt = Zero::zero();
  124. while t != one {
  125. t = modulus(&(&t * &t), q);
  126. m += &one;
  127. }
  128. if m == zero {
  129. return Ok(y);
  130. }
  131. t = g.modpow(&(2.to_bigint().unwrap().modpow(&(&r - &m - 1), q)), q);
  132. g = g.modpow(&(2.to_bigint().unwrap().modpow(&(r - &m), q)), q);
  133. y = modulus(&(y * t), q);
  134. b = modulus(&(b * &g), q);
  135. r = m.clone();
  136. }
  137. }
  138. #[allow(dead_code)]
  139. #[allow(clippy::many_single_char_names)]
  140. pub fn modsqrt_v2(a: &BigInt, q: &BigInt) -> Result<BigInt, String> {
  141. // Tonelli-Shanks Algorithm (https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm)
  142. //
  143. // This implementation is following this Python implementation by Dusk https://github.com/dusk-network/dusk-zerocaf/blob/master/tools/tonelli.py
  144. let zero: BigInt = Zero::zero();
  145. let one: BigInt = One::one();
  146. if legendre_symbol(a, q) != 1 || a == &zero || q == &2.to_bigint().unwrap() {
  147. return Err("not a mod p square".to_string());
  148. } else if q % 4.to_bigint().unwrap() == 3.to_bigint().unwrap() {
  149. let r = a.modpow(&((q + one) / 4), q);
  150. return Ok(r);
  151. }
  152. let mut p = q - &one;
  153. let mut s: BigInt = Zero::zero();
  154. while &p % 2.to_bigint().unwrap() == zero {
  155. s += &one;
  156. p >>= 1;
  157. }
  158. let mut z: BigInt = One::one();
  159. while legendre_symbol(&z, q) != -1 {
  160. z = &z + &one;
  161. }
  162. let mut c = z.modpow(&p, q);
  163. let mut x = a.modpow(&((&p + &one) >> 1), q);
  164. let mut t = a.modpow(&p, q);
  165. let mut m = s;
  166. while t != one {
  167. let mut i: BigInt = One::one();
  168. let mut e: BigInt = 2.to_bigint().unwrap();
  169. while i < m {
  170. if t.modpow(&e, q) == one {
  171. break;
  172. }
  173. e *= 2.to_bigint().unwrap();
  174. i += &one;
  175. }
  176. let b = c.modpow(&(2.to_bigint().unwrap().modpow(&(&m - &i - 1), q)), q);
  177. x = modulus(&(x * &b), q);
  178. t = modulus(&(t * &b * &b), q);
  179. c = modulus(&(&b * &b), q);
  180. m = i.clone();
  181. }
  182. Ok(x)
  183. }
  184. pub fn legendre_symbol(a: &BigInt, q: &BigInt) -> i32 {
  185. // returns 1 if has a square root modulo q
  186. let one: BigInt = One::one();
  187. let ls: BigInt = a.modpow(&((q - &one) >> 1), q);
  188. if ls == q - one {
  189. return -1;
  190. }
  191. 1
  192. }
  193. #[cfg(test)]
  194. mod tests {
  195. use super::*;
  196. #[test]
  197. fn test_mod_inverse() {
  198. let a = BigInt::parse_bytes(b"123456789123456789123456789123456789123456789", 10).unwrap();
  199. let b = BigInt::parse_bytes(b"12345678", 10).unwrap();
  200. assert_eq!(
  201. modinv(&a, &b).unwrap(),
  202. BigInt::parse_bytes(b"641883", 10).unwrap()
  203. );
  204. }
  205. #[test]
  206. fn test_sqrtmod() {
  207. let a = BigInt::parse_bytes(
  208. b"6536923810004159332831702809452452174451353762940761092345538667656658715568",
  209. 10,
  210. )
  211. .unwrap();
  212. let q = BigInt::parse_bytes(
  213. b"7237005577332262213973186563042994240857116359379907606001950938285454250989",
  214. 10,
  215. )
  216. .unwrap();
  217. assert_eq!(
  218. (modsqrt(&a, &q).unwrap()).to_string(),
  219. "5464794816676661649783249706827271879994893912039750480019443499440603127256"
  220. );
  221. assert_eq!(
  222. (modsqrt_v2(&a, &q).unwrap()).to_string(),
  223. "5464794816676661649783249706827271879994893912039750480019443499440603127256"
  224. );
  225. }
  226. }