You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

474 lines
12 KiB

  1. // Copyright 2020 ConsenSys AG
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Code generated by goff (v0.2.0) DO NOT EDIT
  15. // Package ff contains field arithmetic operations
  16. package ff
  17. import (
  18. "crypto/rand"
  19. "math/big"
  20. "math/bits"
  21. mrand "math/rand"
  22. "testing"
  23. )
  24. func TestELEMENTCorrectnessAgainstBigInt(t *testing.T) {
  25. modulus, _ := new(big.Int).SetString("21888242871839275222246405745257275088548364400416034343698204186575808495617", 10)
  26. cmpEandB := func(e *Element, b *big.Int, name string) {
  27. var _e big.Int
  28. if e.FromMont().ToBigInt(&_e).Cmp(b) != 0 {
  29. t.Fatal(name, "failed")
  30. }
  31. }
  32. var modulusMinusOne, one big.Int
  33. one.SetUint64(1)
  34. modulusMinusOne.Sub(modulus, &one)
  35. var n int
  36. if testing.Short() {
  37. n = 10
  38. } else {
  39. n = 500
  40. }
  41. for i := 0; i < n; i++ {
  42. // sample 2 random big int
  43. b1, _ := rand.Int(rand.Reader, modulus)
  44. b2, _ := rand.Int(rand.Reader, modulus)
  45. rExp := mrand.Uint64()
  46. // adding edge cases
  47. // TODO need more edge cases
  48. switch i {
  49. case 0:
  50. rExp = 0
  51. b1.SetUint64(0)
  52. case 1:
  53. b2.SetUint64(0)
  54. case 2:
  55. b1.SetUint64(0)
  56. b2.SetUint64(0)
  57. case 3:
  58. rExp = 0
  59. case 4:
  60. rExp = 1
  61. case 5:
  62. rExp = ^uint64(0) // max uint
  63. case 6:
  64. rExp = 2
  65. b1.Set(&modulusMinusOne)
  66. case 7:
  67. b2.Set(&modulusMinusOne)
  68. case 8:
  69. b1.Set(&modulusMinusOne)
  70. b2.Set(&modulusMinusOne)
  71. }
  72. rbExp := new(big.Int).SetUint64(rExp)
  73. var bMul, bAdd, bSub, bDiv, bNeg, bLsh, bInv, bExp, bExp2, bSquare big.Int
  74. // e1 = mont(b1), e2 = mont(b2)
  75. var e1, e2, eMul, eAdd, eSub, eDiv, eNeg, eLsh, eInv, eExp, eSquare, eMulAssign, eSubAssign, eAddAssign Element
  76. e1.SetBigInt(b1)
  77. e2.SetBigInt(b2)
  78. // (e1*e2).FromMont() === b1*b2 mod q ... etc
  79. eSquare.Square(&e1)
  80. eMul.Mul(&e1, &e2)
  81. eMulAssign.Set(&e1)
  82. eMulAssign.MulAssign(&e2)
  83. eAdd.Add(&e1, &e2)
  84. eAddAssign.Set(&e1)
  85. eAddAssign.AddAssign(&e2)
  86. eSub.Sub(&e1, &e2)
  87. eSubAssign.Set(&e1)
  88. eSubAssign.SubAssign(&e2)
  89. eDiv.Div(&e1, &e2)
  90. eNeg.Neg(&e1)
  91. eInv.Inverse(&e1)
  92. eExp.Exp(e1, rExp)
  93. eLsh.Double(&e1)
  94. // same operations with big int
  95. bAdd.Add(b1, b2).Mod(&bAdd, modulus)
  96. bMul.Mul(b1, b2).Mod(&bMul, modulus)
  97. bSquare.Mul(b1, b1).Mod(&bSquare, modulus)
  98. bSub.Sub(b1, b2).Mod(&bSub, modulus)
  99. bDiv.ModInverse(b2, modulus)
  100. bDiv.Mul(&bDiv, b1).
  101. Mod(&bDiv, modulus)
  102. bNeg.Neg(b1).Mod(&bNeg, modulus)
  103. bInv.ModInverse(b1, modulus)
  104. bExp.Exp(b1, rbExp, modulus)
  105. bLsh.Lsh(b1, 1).Mod(&bLsh, modulus)
  106. cmpEandB(&eSquare, &bSquare, "Square")
  107. cmpEandB(&eMul, &bMul, "Mul")
  108. cmpEandB(&eMulAssign, &bMul, "MulAssign")
  109. cmpEandB(&eAdd, &bAdd, "Add")
  110. cmpEandB(&eAddAssign, &bAdd, "AddAssign")
  111. cmpEandB(&eSub, &bSub, "Sub")
  112. cmpEandB(&eSubAssign, &bSub, "SubAssign")
  113. cmpEandB(&eDiv, &bDiv, "Div")
  114. cmpEandB(&eNeg, &bNeg, "Neg")
  115. cmpEandB(&eInv, &bInv, "Inv")
  116. cmpEandB(&eExp, &bExp, "Exp")
  117. cmpEandB(&eLsh, &bLsh, "Lsh")
  118. // legendre symbol
  119. if e1.Legendre() != big.Jacobi(b1, modulus) {
  120. t.Fatal("legendre symbol computation failed")
  121. }
  122. if e2.Legendre() != big.Jacobi(b2, modulus) {
  123. t.Fatal("legendre symbol computation failed")
  124. }
  125. // these are slow, killing circle ci
  126. if n <= 5 {
  127. // sqrt
  128. var eSqrt, eExp2 Element
  129. var bSqrt big.Int
  130. bSqrt.ModSqrt(b1, modulus)
  131. eSqrt.Sqrt(&e1)
  132. cmpEandB(&eSqrt, &bSqrt, "Sqrt")
  133. bits := b2.Bits()
  134. exponent := make([]uint64, len(bits))
  135. for k := 0; k < len(bits); k++ {
  136. exponent[k] = uint64(bits[k])
  137. }
  138. eExp2.Exp(e1, exponent...)
  139. bExp2.Exp(b1, b2, modulus)
  140. cmpEandB(&eExp2, &bExp2, "Exp multi words")
  141. }
  142. }
  143. }
  144. func TestELEMENTIsRandom(t *testing.T) {
  145. for i := 0; i < 50; i++ {
  146. var x, y Element
  147. x.SetRandom()
  148. y.SetRandom()
  149. if x.Equal(&y) {
  150. t.Fatal("2 random numbers are unlikely to be equal")
  151. }
  152. }
  153. }
// -------------------------------------------------------------------------------------------------
// benchmarks
// Most benchmarks are rudimentary: they should sample a large number of random inputs,
// or be run multiple times, to make sure they did not only measure the fastest path of the function.

// benchResElement is the shared, package-level destination for benchmark results.
// Writing results to a package variable keeps the compiler from discarding the
// benchmarked calls as dead code. Note that several benchmarks reuse it as both
// input and output.
var benchResElement Element
  159. func BenchmarkInverseELEMENT(b *testing.B) {
  160. var x Element
  161. x.SetRandom()
  162. benchResElement.SetRandom()
  163. b.ResetTimer()
  164. for i := 0; i < b.N; i++ {
  165. benchResElement.Inverse(&x)
  166. }
  167. }
  168. func BenchmarkExpELEMENT(b *testing.B) {
  169. var x Element
  170. x.SetRandom()
  171. benchResElement.SetRandom()
  172. b.ResetTimer()
  173. for i := 0; i < b.N; i++ {
  174. benchResElement.Exp(x, mrand.Uint64())
  175. }
  176. }
  177. func BenchmarkDoubleELEMENT(b *testing.B) {
  178. benchResElement.SetRandom()
  179. b.ResetTimer()
  180. for i := 0; i < b.N; i++ {
  181. benchResElement.Double(&benchResElement)
  182. }
  183. }
  184. func BenchmarkAddELEMENT(b *testing.B) {
  185. var x Element
  186. x.SetRandom()
  187. benchResElement.SetRandom()
  188. b.ResetTimer()
  189. for i := 0; i < b.N; i++ {
  190. benchResElement.Add(&x, &benchResElement)
  191. }
  192. }
  193. func BenchmarkSubELEMENT(b *testing.B) {
  194. var x Element
  195. x.SetRandom()
  196. benchResElement.SetRandom()
  197. b.ResetTimer()
  198. for i := 0; i < b.N; i++ {
  199. benchResElement.Sub(&x, &benchResElement)
  200. }
  201. }
  202. func BenchmarkNegELEMENT(b *testing.B) {
  203. benchResElement.SetRandom()
  204. b.ResetTimer()
  205. for i := 0; i < b.N; i++ {
  206. benchResElement.Neg(&benchResElement)
  207. }
  208. }
  209. func BenchmarkDivELEMENT(b *testing.B) {
  210. var x Element
  211. x.SetRandom()
  212. benchResElement.SetRandom()
  213. b.ResetTimer()
  214. for i := 0; i < b.N; i++ {
  215. benchResElement.Div(&x, &benchResElement)
  216. }
  217. }
  218. func BenchmarkFromMontELEMENT(b *testing.B) {
  219. benchResElement.SetRandom()
  220. b.ResetTimer()
  221. for i := 0; i < b.N; i++ {
  222. benchResElement.FromMont()
  223. }
  224. }
  225. func BenchmarkToMontELEMENT(b *testing.B) {
  226. benchResElement.SetRandom()
  227. b.ResetTimer()
  228. for i := 0; i < b.N; i++ {
  229. benchResElement.ToMont()
  230. }
  231. }
  232. func BenchmarkSquareELEMENT(b *testing.B) {
  233. benchResElement.SetRandom()
  234. b.ResetTimer()
  235. for i := 0; i < b.N; i++ {
  236. benchResElement.Square(&benchResElement)
  237. }
  238. }
  239. func BenchmarkSqrtELEMENT(b *testing.B) {
  240. var a Element
  241. a.SetRandom()
  242. b.ResetTimer()
  243. for i := 0; i < b.N; i++ {
  244. benchResElement.Sqrt(&a)
  245. }
  246. }
  247. func BenchmarkMulAssignELEMENT(b *testing.B) {
  248. x := Element{
  249. 1997599621687373223,
  250. 6052339484930628067,
  251. 10108755138030829701,
  252. 150537098327114917,
  253. }
  254. benchResElement.SetOne()
  255. b.ResetTimer()
  256. for i := 0; i < b.N; i++ {
  257. benchResElement.MulAssign(&x)
  258. }
  259. }
  260. func BenchmarkMulAssignASMELEMENT(b *testing.B) {
  261. x := Element{
  262. 1997599621687373223,
  263. 6052339484930628067,
  264. 10108755138030829701,
  265. 150537098327114917,
  266. }
  267. benchResElement.SetOne()
  268. b.ResetTimer()
  269. for i := 0; i < b.N; i++ {
  270. MulAssignElement(&benchResElement, &x)
  271. }
  272. }
  273. func TestELEMENTAsm(t *testing.T) {
  274. // ensure ASM implementations matches the ones using math/bits
  275. modulus, _ := new(big.Int).SetString("21888242871839275222246405745257275088548364400416034343698204186575808495617", 10)
  276. for i := 0; i < 500; i++ {
  277. // sample 2 random big int
  278. b1, _ := rand.Int(rand.Reader, modulus)
  279. b2, _ := rand.Int(rand.Reader, modulus)
  280. // e1 = mont(b1), e2 = mont(b2)
  281. var e1, e2, eTestMul, eMulAssign, eSquare, eTestSquare Element
  282. e1.SetBigInt(b1)
  283. e2.SetBigInt(b2)
  284. eTestMul = e1
  285. eTestMul.testMulAssign(&e2)
  286. eMulAssign = e1
  287. eMulAssign.MulAssign(&e2)
  288. if !eTestMul.Equal(&eMulAssign) {
  289. t.Fatal("inconsisntencies between MulAssign and testMulAssign --> check if MulAssign is calling ASM implementaiton on amd64")
  290. }
  291. // square
  292. eSquare.Square(&e1)
  293. eTestSquare.testSquare(&e1)
  294. if !eTestSquare.Equal(&eSquare) {
  295. t.Fatal("inconsisntencies between Square and testSquare --> check if Square is calling ASM implementaiton on amd64")
  296. }
  297. }
  298. }
// testMulAssign is a pure-Go reference Montgomery multiplication used only by
// TestELEMENTAsm. It checks that MulAssign, which may use an assembly
// implementation on amd64, gives consistent results. It overwrites z with the
// Montgomery product of z and x and returns z.
//
// The algorithm interleaves multiplication and reduction word by word. It runs
// four rounds, one per word of z, and each round ends by discarding one 64-bit
// word. That makes the result z*x*R⁻¹ mod q with R = 2^256, which is ordinary
// multiplication when both operands are in Montgomery form. A final conditional
// subtraction of q follows.
//
// Constants:
//   - 4891460686036598785, 2896914383306846353, 13281191951274694749 and
//     3486998266802970665 are the words of the modulus q, least-significant
//     first. The final reduction subtracts exactly these from z[0]..z[3].
//   - 14042775128853446655 scales the lowest accumulator word into m each round.
//     It is presumably -q⁻¹ mod 2^64, so that adding m*q cancels that word.
//     NOTE(review): confirm against the generator.
//
// madd0..madd3 are defined elsewhere in the package. From their use they appear
// to be multiply-add helpers returning (high word, low word), like bits.Mul64.
// NOTE(review): confirm their exact semantics.
func (z *Element) testMulAssign(x *Element) *Element {
	var t [4]uint64 // running accumulator between rounds
	var c [3]uint64 // per-round carries / partial words
	{
		// round 0: start the accumulator with z[0]*x, then fold in m*q
		v := z[0]
		c[1], c[0] = bits.Mul64(v, x[0])
		m := c[0] * 14042775128853446655
		c[2] = madd0(m, 4891460686036598785, c[0])
		c[1], c[0] = madd1(v, x[1], c[1])
		c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0])
		c[1], c[0] = madd1(v, x[2], c[1])
		c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0])
		c[1], c[0] = madd1(v, x[3], c[1])
		t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1])
	}
	{
		// round 1: accumulate z[1]*x into t, then fold in m*q
		v := z[1]
		c[1], c[0] = madd1(v, x[0], t[0])
		m := c[0] * 14042775128853446655
		c[2] = madd0(m, 4891460686036598785, c[0])
		c[1], c[0] = madd2(v, x[1], c[1], t[1])
		c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0])
		c[1], c[0] = madd2(v, x[2], c[1], t[2])
		c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0])
		c[1], c[0] = madd2(v, x[3], c[1], t[3])
		t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1])
	}
	{
		// round 2: accumulate z[2]*x into t, then fold in m*q
		v := z[2]
		c[1], c[0] = madd1(v, x[0], t[0])
		m := c[0] * 14042775128853446655
		c[2] = madd0(m, 4891460686036598785, c[0])
		c[1], c[0] = madd2(v, x[1], c[1], t[1])
		c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0])
		c[1], c[0] = madd2(v, x[2], c[1], t[2])
		c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0])
		c[1], c[0] = madd2(v, x[3], c[1], t[3])
		t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1])
	}
	{
		// round 3: same as the previous rounds, but the words are written
		// straight into z (v already holds z[3], so overwriting z is safe here)
		v := z[3]
		c[1], c[0] = madd1(v, x[0], t[0])
		m := c[0] * 14042775128853446655
		c[2] = madd0(m, 4891460686036598785, c[0])
		c[1], c[0] = madd2(v, x[1], c[1], t[1])
		c[2], z[0] = madd2(m, 2896914383306846353, c[2], c[0])
		c[1], c[0] = madd2(v, x[2], c[1], t[2])
		c[2], z[1] = madd2(m, 13281191951274694749, c[2], c[0])
		c[1], c[0] = madd2(v, x[3], c[1], t[3])
		z[3], z[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1])
	}
	// if z >= q --> z -= q
	// (the condition is "not z < q", compared from the most significant word down)
	// note: this is NOT constant time
	if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) {
		var b uint64
		z[0], b = bits.Sub64(z[0], 4891460686036598785, 0)
		z[1], b = bits.Sub64(z[1], 2896914383306846353, b)
		z[2], b = bits.Sub64(z[2], 13281191951274694749, b)
		z[3], _ = bits.Sub64(z[3], 3486998266802970665, b)
	}
	return z
}
// testSquare is a pure-Go reference Montgomery squaring used only by
// TestELEMENTAsm. It checks that Square, which may use an assembly
// implementation on amd64, gives consistent results. It sets z to the Montgomery
// square of x and returns z.
//
// Like testMulAssign, it runs four word-by-word rounds and then a final
// conditional subtraction of q. It reuses the same modulus words
// (4891460686036598785, 2896914383306846353, 13281191951274694749,
// 3486998266802970665, least-significant first) and the same per-round factor
// 14042775128853446655. Each round handles one diagonal term x[i]*x[i] via madd1
// plus the off-diagonal terms x[i]*x[j] (j > i). The madd*s / madd*sb helpers
// presumably account for each off-diagonal product appearing twice in a square.
// NOTE(review): those helpers are defined elsewhere; confirm their semantics.
func (z *Element) testSquare(x *Element) *Element {
	var p [4]uint64 // running accumulator between rounds
	var u, v uint64 // (high, low) words carried within a round
	{
		// round 0
		u, p[0] = bits.Mul64(x[0], x[0])
		m := p[0] * 14042775128853446655
		C := madd0(m, 4891460686036598785, p[0])
		var t uint64
		t, u, v = madd1sb(x[0], x[1], u)
		C, p[0] = madd2(m, 2896914383306846353, v, C)
		t, u, v = madd1s(x[0], x[2], t, u)
		C, p[1] = madd2(m, 13281191951274694749, v, C)
		_, u, v = madd1s(x[0], x[3], t, u)
		p[3], p[2] = madd3(m, 3486998266802970665, v, C, u)
	}
	{
		// round 1
		m := p[0] * 14042775128853446655
		C := madd0(m, 4891460686036598785, p[0])
		u, v = madd1(x[1], x[1], p[1])
		C, p[0] = madd2(m, 2896914383306846353, v, C)
		var t uint64
		t, u, v = madd2sb(x[1], x[2], p[2], u)
		C, p[1] = madd2(m, 13281191951274694749, v, C)
		_, u, v = madd2s(x[1], x[3], p[3], t, u)
		p[3], p[2] = madd3(m, 3486998266802970665, v, C, u)
	}
	{
		// round 2
		m := p[0] * 14042775128853446655
		C := madd0(m, 4891460686036598785, p[0])
		C, p[0] = madd2(m, 2896914383306846353, p[1], C)
		u, v = madd1(x[2], x[2], p[2])
		C, p[1] = madd2(m, 13281191951274694749, v, C)
		_, u, v = madd2sb(x[2], x[3], p[3], u)
		p[3], p[2] = madd3(m, 3486998266802970665, v, C, u)
	}
	{
		// round 3: the result words are written straight into z
		m := p[0] * 14042775128853446655
		C := madd0(m, 4891460686036598785, p[0])
		C, z[0] = madd2(m, 2896914383306846353, p[1], C)
		C, z[1] = madd2(m, 13281191951274694749, p[2], C)
		u, v = madd1(x[3], x[3], p[3])
		z[3], z[2] = madd3(m, 3486998266802970665, v, C, u)
	}
	// if z >= q --> z -= q
	// (the condition is "not z < q", compared from the most significant word down)
	// note: this is NOT constant time
	if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) {
		var b uint64
		z[0], b = bits.Sub64(z[0], 4891460686036598785, 0)
		z[1], b = bits.Sub64(z[1], 2896914383306846353, b)
		z[2], b = bits.Sub64(z[2], 13281191951274694749, b)
		z[3], _ = bits.Sub64(z[3], 3486998266802970665, b)
	}
	return z
}