You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

229 lines
6.0 KiB

  1. package snark
  2. import (
  3. "fmt"
  4. "github.com/arnaucube/go-snark/circuitcompiler"
  5. "github.com/arnaucube/go-snark/r1csqap"
  6. "github.com/stretchr/testify/assert"
  7. "math/big"
  8. "strings"
  9. "testing"
  10. "time"
  11. )
  // InOut pairs one set of circuit inputs with the value the circuit is
// expected to produce for them.
type InOut struct {
	// inputs are the circuit inputs handed to circuitcompiler.CalculateWitness.
	inputs []*big.Int
	// result is the expected output; the test compares it against the last
	// entry of the computed witness.
	result *big.Int
}
  16. type TraceCorrectnessTest struct {
  17. code string
  18. io []InOut
  19. }
  // bigNumberResult1 is the expected output of the first (currently
// commented-out) case in correctnesTest for inputs 365235 and 11876525.
var bigNumberResult1 = mustParseBase10BigInt("2297704271284150716235246193843898764109352875")

// bigNumberResult2 is the expected output of the third (currently
// commented-out) case in correctnesTest for inputs 365235 and 11876525.
var bigNumberResult2 = mustParseBase10BigInt("75263346540254220740876250")

// mustParseBase10BigInt parses s as a base-10 integer and panics if s is not
// a valid literal. big.Int.SetString returns nil on failure, so discarding its
// ok result would silently turn a typo in a test constant into a nil *big.Int
// that only surfaces later as a confusing assertion failure.
func mustParseBase10BigInt(s string) *big.Int {
	n, ok := new(big.Int).SetString(s, 10)
	if !ok {
		panic(fmt.Sprintf("invalid base-10 integer literal %q", s))
	}
	return n
}
  22. var correctnesTest = []TraceCorrectnessTest{
  23. //{
  24. // io: []InOut{{
  25. // inputs: []*big.Int{big.NewInt(int64(7)), big.NewInt(int64(11))},
  26. // result: big.NewInt(int64(1729500084900343)),
  27. // }, {
  28. // inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))},
  29. //
  30. // result: bigNumberResult1,
  31. // }},
  32. // code: `
  33. //def main( x , z ) :
  34. // out = do(z) + add(x,x)
  35. //
  36. //def do(x):
  37. // e = x * 5
  38. // b = e * 6
  39. // c = b * 7
  40. // f = c * 1
  41. // d = c * f
  42. // out = d * mul(d,e)
  43. //
  44. //def add(x ,k):
  45. // z = k * x
  46. // out = do(x) + mul(x,z)
  47. //
  48. //
  49. //def mul(a,b):
  50. // out = a * b
  51. //`,
  52. //},
  53. //{io: []InOut{{
  54. // inputs: []*big.Int{big.NewInt(int64(7))},
  55. // result: big.NewInt(int64(4)),
  56. //}},
  57. // code: `
  58. //def mul(a,b):
  59. // out = a * b
  60. //
  61. //def main(a):
  62. // b = a * a
  63. // c = 4 - b
  64. // d = 5 * c
  65. // out = mul(d,c) / mul(b,b)
  66. //`,
  67. //},
  68. //{io: []InOut{{
  69. // inputs: []*big.Int{big.NewInt(int64(7)), big.NewInt(int64(11))},
  70. // result: big.NewInt(int64(22638)),
  71. //}, {
  72. // inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))},
  73. // result: bigNumberResult2,
  74. //}},
  75. // code: `
  76. //def main(a,b):
  77. // d = b + b
  78. // c = a * d
  79. // e = c - a
  80. // out = e * c
  81. //`,
  82. //},
  83. //{
  84. // io: []InOut{{
  85. // inputs: []*big.Int{big.NewInt(int64(643)), big.NewInt(int64(76548465))},
  86. // result: big.NewInt(int64(98441327276)),
  87. // }, {
  88. // inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))},
  89. // result: big.NewInt(int64(8675445947220)),
  90. // }},
  91. // code: `
  92. //def main(a,b):
  93. // c = a + b
  94. // e = c - a
  95. // f = e + b
  96. // g = f + 2
  97. // out = g * a
  98. //`,
  99. //},
  100. {
  101. io: []InOut{{
  102. inputs: []*big.Int{big.NewInt(int64(3)), big.NewInt(int64(5)), big.NewInt(int64(7)), big.NewInt(int64(11))},
  103. result: big.NewInt(int64(444675)),
  104. }},
  105. code: `
  106. def main(a,b,c,d):
  107. e = a * b
  108. f = c * d
  109. g = e * f
  110. h = g / e
  111. i = h * 5
  112. out = g * i
  113. `,
  114. },
  115. }
  116. func TestGenerateAndVerifyProof(t *testing.T) {
  117. for _, test := range correctnesTest {
  118. parser := circuitcompiler.NewParser(strings.NewReader(test.code))
  119. program, err := parser.Parse()
  120. if err != nil {
  121. panic(err)
  122. }
  123. fmt.Println("\n unreduced")
  124. fmt.Println(test.code)
  125. program.BuildConstraintTrees()
  126. program.PrintContraintTrees()
  127. fmt.Println("\nReduced gates")
  128. //PrintTree(froots["mul"])
  129. gates := program.ReduceCombinedTree()
  130. for _, g := range gates {
  131. fmt.Println(g)
  132. }
  133. fmt.Println("generating R1CS")
  134. //NOTE MOVE DOES NOTHING CURRENTLY
  135. r1cs := moveOutputToBegining(program.GenerateReducedR1CS(gates))
  136. //[[0 1 0 0 0 0 0 0 0 0] [0 0 0 1 0 0 0 0 0 0] [0 0 0 0 0 1 0 0 0 0] [0 0 0 0 0 0 0 0 1 0] [0 0 0 0 0 0 0 1 0 0]]
  137. //[[0 0 1 0 0 0 0 0 0 0] [0 0 0 0 1 0 0 0 0 0] [0 0 0 0 0 0 1 0 0 0] [0 0 0 0 0 1 0 0 0 0] [0 0 0 0 0 0 0 0 5 0]]
  138. //[[0 0 0 0 0 1 0 0 0 0] [0 0 0 0 0 0 1 0 0 0] [0 0 0 0 0 0 0 1 0 0] [0 0 0 0 0 0 0 1 0 0] [0 0 0 0 0 0 0 0 0 1]]
  139. a, b, c := r1cs.A, r1cs.B, r1cs.C
  140. fmt.Println(a)
  141. fmt.Println(b)
  142. fmt.Println(c)
  143. // R1CS to QAP
  144. alphas, betas, gammas, domain := Utils.PF.R1CSToQAP(a, b, c)
  145. fmt.Println("QAP array lengths")
  146. fmt.Println("alphas", len(alphas))
  147. fmt.Println("betas", len(betas))
  148. fmt.Println("gammas", len(gammas))
  149. fmt.Println("domain polynomial ", len(domain))
  150. before := time.Now()
  151. //calculate trusted setup
  152. setup, err := GenerateTrustedSetup(len(alphas[0]), alphas, betas, gammas)
  153. fmt.Println("Generate CRS time elapsed:", time.Since(before))
  154. assert.Nil(t, err)
  155. fmt.Println("\nt:", setup.Toxic.T)
  156. for _, io := range test.io {
  157. inputs := io.inputs
  158. fmt.Println("input")
  159. fmt.Println(inputs)
  160. w := circuitcompiler.CalculateWitness(inputs, r1cs)
  161. fmt.Println("\nwitness", w)
  162. //NOTE MOVE DOES NOTHING
  163. w = moveWitnessOutputAfterInputs(program.GlobalInputCount(), w)
  164. fmt.Println("\nwitness Reordered ", w)
  165. assert.Equal(t, io.result, w[len(w)-1])
  166. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  167. fmt.Println("ax length", len(ax))
  168. fmt.Println("bx length", len(bx))
  169. fmt.Println("cx length", len(cx))
  170. fmt.Println("px length", len(px))
  171. hxQAP := Utils.PF.DivisorPolynomial(px, domain)
  172. fmt.Println("hx length", len(hxQAP))
  173. // hx==px/zx so px==hx*zx
  174. assert.Equal(t, px, Utils.PF.Mul(hxQAP, domain))
  175. // p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
  176. abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx)
  177. assert.Equal(t, abc, px)
  178. div, rem := Utils.PF.Div(px, domain)
  179. assert.Equal(t, hxQAP, div) //not necessary, since DivisorPolynomial is Div, just discarding 'rem'
  180. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(len(px)-len(domain)))
  181. //// zx and setup.Pk.Z should be the same (currently not, the correct one is the calculation used inside GenerateTrustedSetup function), the calculation is repeated. TODO avoid repeating calculation
  182. //assert.Equal(t, domain, setup.Pk.Z)
  183. hx := Utils.PF.DivisorPolynomial(px, setup.Pk.Z)
  184. // assert.Equal(t, hxQAP, hx)
  185. assert.Equal(t, px, Utils.PF.Mul(hxQAP, domain))
  186. assert.Equal(t, px, Utils.PF.Mul(hx, setup.Pk.Z))
  187. assert.Equal(t, len(hx), len(px)-len(setup.Pk.Z)+1)
  188. assert.Equal(t, len(hxQAP), len(px)-len(domain)+1)
  189. before := time.Now()
  190. // piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t)
  191. proof, err := GenerateProofs(setup, program.GlobalInputCount(), w, px)
  192. fmt.Println("proof generation time elapsed:", time.Since(before))
  193. assert.Nil(t, err)
  194. before = time.Now()
  195. assert.True(t, VerifyProof(setup, proof, append(w[1:program.GlobalInputCount()], w[len(w)-1]), true))
  196. fmt.Println("verify proof time elapsed:", time.Since(before))
  197. }
  198. }
  199. }