You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

227 lines
5.3 KiB

  1. package prover
import (
	"crypto/rand"
	"fmt"
	"math"
	"math/big"
	"runtime"
	"sync"

	bn256 "github.com/ethereum/go-ethereum/crypto/bn256/cloudflare"
	"github.com/iden3/go-circom-prover-verifier/types"
	"github.com/iden3/go-iden3-crypto/utils"
)
  12. // Proof is the data structure of the Groth16 zkSNARK proof
  13. type Proof struct {
  14. A *bn256.G1
  15. B *bn256.G2
  16. C *bn256.G1
  17. }
  18. // Pk holds the data structure of the ProvingKey
  19. type Pk struct {
  20. A []*bn256.G1
  21. B2 []*bn256.G2
  22. B1 []*bn256.G1
  23. C []*bn256.G1
  24. NVars int
  25. NPublic int
  26. VkAlpha1 *bn256.G1
  27. VkDelta1 *bn256.G1
  28. VkBeta1 *bn256.G1
  29. VkBeta2 *bn256.G2
  30. VkDelta2 *bn256.G2
  31. HExps []*bn256.G1
  32. DomainSize int
  33. PolsA []map[int]*big.Int
  34. PolsB []map[int]*big.Int
  35. PolsC []map[int]*big.Int
  36. }
  37. // Witness contains the witness
  38. type Witness []*big.Int
  39. func randBigInt() (*big.Int, error) {
  40. maxbits := types.R.BitLen()
  41. b := make([]byte, (maxbits/8)-1)
  42. _, err := rand.Read(b)
  43. if err != nil {
  44. return nil, err
  45. }
  46. r := new(big.Int).SetBytes(b)
  47. rq := new(big.Int).Mod(r, types.R)
  48. return rq, nil
  49. }
  50. // GenerateProof generates the Groth16 zkSNARK proof
  51. func GenerateProof(pk *types.Pk, w types.Witness) (*types.Proof, []*big.Int, error) {
  52. var proof types.Proof
  53. r, err := randBigInt()
  54. if err != nil {
  55. return nil, nil, err
  56. }
  57. s, err := randBigInt()
  58. if err != nil {
  59. return nil, nil, err
  60. }
  61. // BEGIN PAR
  62. numcpu := runtime.NumCPU()
  63. proofA := arrayOfZeroesG1(numcpu)
  64. proofB := arrayOfZeroesG2(numcpu)
  65. proofC := arrayOfZeroesG1(numcpu)
  66. proofBG1 := arrayOfZeroesG1(numcpu)
  67. var wg1 sync.WaitGroup
  68. wg1.Add(numcpu)
  69. for _cpu, _ranges := range ranges(pk.NVars, numcpu) {
  70. // split 1
  71. go func(cpu int, ranges [2]int) {
  72. for i := ranges[0]; i < ranges[1]; i++ {
  73. proofA[cpu].Add(proofA[cpu], new(bn256.G1).ScalarMult(pk.A[i], w[i]))
  74. proofB[cpu].Add(proofB[cpu], new(bn256.G2).ScalarMult(pk.B2[i], w[i]))
  75. proofBG1[cpu].Add(proofBG1[cpu], new(bn256.G1).ScalarMult(pk.B1[i], w[i]))
  76. if i >= pk.NPublic+1 {
  77. proofC[cpu].Add(proofC[cpu], new(bn256.G1).ScalarMult(pk.C[i], w[i]))
  78. }
  79. }
  80. wg1.Done()
  81. }(_cpu, _ranges)
  82. }
  83. wg1.Wait()
  84. // join 1
  85. for cpu := 1; cpu < numcpu; cpu++ {
  86. proofA[0].Add(proofA[0], proofA[cpu])
  87. proofB[0].Add(proofB[0], proofB[cpu])
  88. proofC[0].Add(proofC[0], proofC[cpu])
  89. proofBG1[0].Add(proofBG1[0], proofBG1[cpu])
  90. }
  91. proof.A = proofA[0]
  92. proof.B = proofB[0]
  93. proof.C = proofC[0]
  94. // END PAR
  95. h := calculateH(pk, w)
  96. proof.A.Add(proof.A, pk.VkAlpha1)
  97. proof.A.Add(proof.A, new(bn256.G1).ScalarMult(pk.VkDelta1, r))
  98. proof.B.Add(proof.B, pk.VkBeta2)
  99. proof.B.Add(proof.B, new(bn256.G2).ScalarMult(pk.VkDelta2, s))
  100. proofBG1[0].Add(proofBG1[0], pk.VkBeta1)
  101. proofBG1[0].Add(proofBG1[0], new(bn256.G1).ScalarMult(pk.VkDelta1, s))
  102. proofC = arrayOfZeroesG1(numcpu)
  103. var wg2 sync.WaitGroup
  104. wg2.Add(numcpu)
  105. for _cpu, _ranges := range ranges(len(h), numcpu) {
  106. // split 2
  107. go func(cpu int, ranges [2]int) {
  108. for i := ranges[0]; i < ranges[1]; i++ {
  109. proofC[cpu].Add(proofC[cpu], new(bn256.G1).ScalarMult(pk.HExps[i], h[i]))
  110. }
  111. wg2.Done()
  112. }(_cpu, _ranges)
  113. }
  114. wg2.Wait()
  115. // join 2
  116. for cpu := 1; cpu < numcpu; cpu++ {
  117. proofC[0].Add(proofC[0], proofC[cpu])
  118. }
  119. proof.C.Add(proof.C, proofC[0])
  120. proof.C.Add(proof.C, new(bn256.G1).ScalarMult(proof.A, s))
  121. proof.C.Add(proof.C, new(bn256.G1).ScalarMult(proofBG1[0], r))
  122. rsneg := new(big.Int).Mod(new(big.Int).Neg(new(big.Int).Mul(r, s)), types.R)
  123. proof.C.Add(proof.C, new(bn256.G1).ScalarMult(pk.VkDelta1, rsneg))
  124. pubSignals := w[1 : pk.NPublic+1]
  125. return &proof, pubSignals, nil
  126. }
  127. func calculateH(pk *types.Pk, w types.Witness) []*big.Int {
  128. m := pk.DomainSize
  129. polAT := arrayOfZeroes(m)
  130. polBT := arrayOfZeroes(m)
  131. numcpu := runtime.NumCPU()
  132. var wg1 sync.WaitGroup
  133. wg1.Add(2)
  134. go func() {
  135. for i := 0; i < pk.NVars; i++ {
  136. for j := range pk.PolsA[i] {
  137. polAT[j] = fAdd(polAT[j], fMul(w[i], pk.PolsA[i][j]))
  138. }
  139. }
  140. wg1.Done()
  141. }()
  142. go func() {
  143. for i := 0; i < pk.NVars; i++ {
  144. for j := range pk.PolsB[i] {
  145. polBT[j] = fAdd(polBT[j], fMul(w[i], pk.PolsB[i][j]))
  146. }
  147. }
  148. wg1.Done()
  149. }()
  150. wg1.Wait()
  151. polATe := utils.BigIntArrayToElementArray(polAT)
  152. polBTe := utils.BigIntArrayToElementArray(polBT)
  153. polASe := ifft(polATe)
  154. polBSe := ifft(polBTe)
  155. r := int(math.Log2(float64(m))) + 1
  156. roots := newRootsT()
  157. roots.setRoots(r)
  158. var wg2 sync.WaitGroup
  159. wg2.Add(numcpu)
  160. for _cpu, _ranges := range ranges(len(polASe), numcpu) {
  161. go func(cpu int, ranges [2]int) {
  162. for i := ranges[0]; i < ranges[1]; i++ {
  163. polASe[i].Mul(polASe[i], roots.roots[r][i])
  164. polBSe[i].Mul(polBSe[i], roots.roots[r][i])
  165. }
  166. wg2.Done()
  167. }(_cpu, _ranges)
  168. }
  169. wg2.Wait()
  170. polATodd := fft(polASe)
  171. polBTodd := fft(polBSe)
  172. polABT := arrayOfZeroesE(len(polASe) * 2)
  173. var wg3 sync.WaitGroup
  174. wg3.Add(numcpu)
  175. for _cpu, _ranges := range ranges(len(polASe), numcpu) {
  176. go func(cpu int, ranges [2]int) {
  177. for i := ranges[0]; i < ranges[1]; i++ {
  178. polABT[2*i].Mul(polATe[i], polBTe[i])
  179. polABT[2*i+1].Mul(polATodd[i], polBTodd[i])
  180. }
  181. wg3.Done()
  182. }(_cpu, _ranges)
  183. }
  184. wg3.Wait()
  185. hSeFull := ifft(polABT)
  186. hSe := hSeFull[m:]
  187. return utils.ElementArrayToBigIntArray(hSe)
  188. }
  189. func ranges(n, parts int) [][2]int {
  190. s := make([][2]int, parts)
  191. p := float64(n) / float64(parts)
  192. for i := 0; i < parts; i++ {
  193. a, b := int(float64(i)*p), int(float64(i+1)*p)
  194. s[i] = [2]int{a, b}
  195. }
  196. return s
  197. }