You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

350 lines
10 KiB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
  1. package proof
  2. import (
  3. "bytes"
  4. "fmt"
  5. "math/big"
  6. "strings"
  7. "testing"
  8. "time"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/arnaucube/go-snark/circuit"
  11. "github.com/arnaucube/go-snark/r1csqap"
  12. )
  13. func TestZkFromFlatCircuitCode(t *testing.T) {
  14. // compile circuit and get the R1CS
  15. // circuit function
  16. // y = x^3 + x + 5
  17. code := `
  18. func exp3(private a):
  19. b = a * a
  20. c = a * b
  21. return c
  22. func sum(private a, private b):
  23. c = a + b
  24. return c
  25. func main(private s0, public s1):
  26. s3 = exp3(s0)
  27. s4 = sum(s3, s0)
  28. s5 = s4 + 5
  29. equals(s1, s5)
  30. out = 1 * 1
  31. `
  32. // the same code without the functions calling, all in one func
  33. // code := `
  34. // func test(private s0, public s1):
  35. // s2 = s0 * s0
  36. // s3 = s2 * s0
  37. // s4 = s3 + s0
  38. // s5 = s4 + 5
  39. // equals(s1, s5)
  40. // out = 1 * 1
  41. // `
  42. fmt.Print("\ncode of the circuit:")
  43. fmt.Println(code)
  44. // parse the code
  45. parser := circuit.NewParser(strings.NewReader(code))
  46. circuit, err := parser.Parse()
  47. assert.Nil(t, err)
  48. // fmt.Println("\ncircuit data:", circuit)
  49. // circuitJson, _ := json.Marshal(circuit)
  50. // fmt.Println("circuit:", string(circuitJson))
  51. b3 := big.NewInt(int64(3))
  52. privateInputs := []*big.Int{b3}
  53. b35 := big.NewInt(int64(35))
  54. publicSignals := []*big.Int{b35}
  55. // wittness
  56. w, err := circuit.CalculateWitness(privateInputs, publicSignals)
  57. assert.Nil(t, err)
  58. // code to R1CS
  59. fmt.Println("\ngenerating R1CS from code")
  60. a, b, c := circuit.GenerateR1CS()
  61. fmt.Println("\nR1CS:")
  62. fmt.Println("a:", a)
  63. fmt.Println("b:", b)
  64. fmt.Println("c:", c)
  65. // R1CS to QAP
  66. // TODO zxQAP is not used and is an old impl, TODO remove
  67. alphas, betas, gammas, zxQAP := Utils.PF.R1CSToQAP(a, b, c)
  68. fmt.Println("qap")
  69. assert.Equal(t, 8, len(alphas))
  70. assert.Equal(t, 8, len(alphas))
  71. assert.Equal(t, 8, len(alphas))
  72. assert.Equal(t, 7, len(zxQAP))
  73. assert.True(t, !bytes.Equal(alphas[1][1].Bytes(), big.NewInt(int64(0)).Bytes()))
  74. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  75. assert.Equal(t, 7, len(ax))
  76. assert.Equal(t, 7, len(bx))
  77. assert.Equal(t, 7, len(cx))
  78. assert.Equal(t, 13, len(px))
  79. hxQAP := Utils.PF.DivisorPolynomial(px, zxQAP)
  80. assert.Equal(t, 7, len(hxQAP))
  81. // hx==px/zx so px==hx*zx
  82. assert.Equal(t, px, Utils.PF.Mul(hxQAP, zxQAP))
  83. // p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
  84. abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx)
  85. assert.Equal(t, abc, px)
  86. hzQAP := Utils.PF.Mul(hxQAP, zxQAP)
  87. assert.Equal(t, abc, hzQAP)
  88. div, rem := Utils.PF.Div(px, zxQAP)
  89. assert.Equal(t, hxQAP, div)
  90. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(6))
  91. // calculate trusted setup
  92. setup := &PinocchioSetup{}
  93. err = setup.Init(len(w), *circuit, alphas, betas, gammas)
  94. assert.Nil(t, err)
  95. fmt.Println("\nt:", setup.Toxic.T)
  96. // zx and setup.Pk.Z should be the same (currently not, the correct one is the calculation used inside GenerateTrustedSetup function), the calculation is repeated. TODO avoid repeating calculation
  97. assert.Equal(t, zxQAP, setup.Pk.Z)
  98. hx := Utils.PF.DivisorPolynomial(px, setup.Pk.Z)
  99. assert.Equal(t, hx, hxQAP)
  100. // assert.Equal(t, hxQAP, hx)
  101. div, rem = Utils.PF.Div(px, setup.Pk.Z)
  102. assert.Equal(t, hx, div)
  103. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(6))
  104. assert.Equal(t, px, Utils.PF.Mul(hxQAP, zxQAP))
  105. // hx==px/zx so px==hx*zx
  106. assert.Equal(t, px, Utils.PF.Mul(hx, setup.Pk.Z))
  107. // check length of polynomials H(x) and Z(x)
  108. assert.Equal(t, len(hx), len(px)-len(setup.Pk.Z)+1)
  109. assert.Equal(t, len(hxQAP), len(px)-len(zxQAP)+1)
  110. proof, err := setup.Generate(*circuit, w, px)
  111. assert.Nil(t, err)
  112. // fmt.Println("\n proofs:")
  113. // fmt.Println(proof)
  114. // fmt.Println("public signals:", proof.PublicSignals)
  115. fmt.Println("\nsignals:", circuit.Signals)
  116. fmt.Println("witness:", w)
  117. b35Verif := big.NewInt(int64(35))
  118. publicSignalsVerif := []*big.Int{b35Verif}
  119. before := time.Now()
  120. assert.True(t, setup.Verify(*circuit, proof, publicSignalsVerif, true))
  121. fmt.Println("verify proof time elapsed:", time.Since(before))
  122. // check that with another public input the verification returns false
  123. bOtherWrongPublic := big.NewInt(int64(34))
  124. wrongPublicSignalsVerif := []*big.Int{bOtherWrongPublic}
  125. assert.True(t, !setup.Verify(*circuit, proof, wrongPublicSignalsVerif, false))
  126. }
  127. func TestZkMultiplication(t *testing.T) {
  128. code := `
  129. func main(private a, private b, public c):
  130. d = a * b
  131. equals(c, d)
  132. out = 1 * 1
  133. `
  134. fmt.Println("code", code)
  135. // parse the code
  136. parser := circuit.NewParser(strings.NewReader(code))
  137. circuit, err := parser.Parse()
  138. assert.Nil(t, err)
  139. b3 := big.NewInt(int64(3))
  140. b4 := big.NewInt(int64(4))
  141. privateInputs := []*big.Int{b3, b4}
  142. b12 := big.NewInt(int64(12))
  143. publicSignals := []*big.Int{b12}
  144. // wittness
  145. w, err := circuit.CalculateWitness(privateInputs, publicSignals)
  146. assert.Nil(t, err)
  147. // code to R1CS
  148. fmt.Println("\ngenerating R1CS from code")
  149. a, b, c := circuit.GenerateR1CS()
  150. fmt.Println("\nR1CS:")
  151. fmt.Println("a:", a)
  152. fmt.Println("b:", b)
  153. fmt.Println("c:", c)
  154. // R1CS to QAP
  155. // TODO zxQAP is not used and is an old impl. TODO remove
  156. alphas, betas, gammas, zxQAP := Utils.PF.R1CSToQAP(a, b, c)
  157. assert.Equal(t, 6, len(alphas))
  158. assert.Equal(t, 6, len(betas))
  159. assert.Equal(t, 6, len(betas))
  160. assert.Equal(t, 5, len(zxQAP))
  161. assert.True(t, !bytes.Equal(alphas[1][1].Bytes(), big.NewInt(int64(0)).Bytes()))
  162. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  163. assert.Equal(t, 4, len(ax))
  164. assert.Equal(t, 4, len(bx))
  165. assert.Equal(t, 4, len(cx))
  166. assert.Equal(t, 7, len(px))
  167. hxQAP := Utils.PF.DivisorPolynomial(px, zxQAP)
  168. assert.Equal(t, 3, len(hxQAP))
  169. // hx==px/zx so px==hx*zx
  170. assert.Equal(t, px, Utils.PF.Mul(hxQAP, zxQAP))
  171. // p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
  172. abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx)
  173. assert.Equal(t, abc, px)
  174. hzQAP := Utils.PF.Mul(hxQAP, zxQAP)
  175. assert.Equal(t, abc, hzQAP)
  176. div, rem := Utils.PF.Div(px, zxQAP)
  177. assert.Equal(t, hxQAP, div)
  178. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4))
  179. // calculate trusted setup
  180. setup := &PinocchioSetup{}
  181. err = setup.Init(len(w), *circuit, alphas, betas, gammas)
  182. assert.Nil(t, err)
  183. // fmt.Println("\nt:", setup.Toxic.T)
  184. // zx and setup.Pk.Z should be the same (currently not, the correct one is the calculation used inside GenerateTrustedSetup function), the calculation is repeated. TODO avoid repeating calculation
  185. assert.Equal(t, zxQAP, setup.Pk.Z)
  186. hx := Utils.PF.DivisorPolynomial(px, setup.Pk.Z)
  187. assert.Equal(t, 3, len(hx))
  188. assert.Equal(t, hx, hxQAP)
  189. div, rem = Utils.PF.Div(px, setup.Pk.Z)
  190. assert.Equal(t, hx, div)
  191. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4))
  192. assert.Equal(t, px, Utils.PF.Mul(hxQAP, zxQAP))
  193. // hx==px/zx so px==hx*zx
  194. assert.Equal(t, px, Utils.PF.Mul(hx, setup.Pk.Z))
  195. // check length of polynomials H(x) and Z(x)
  196. assert.Equal(t, len(hx), len(px)-len(setup.Pk.Z)+1)
  197. assert.Equal(t, len(hxQAP), len(px)-len(zxQAP)+1)
  198. proof, err := setup.Generate(*circuit, w, px)
  199. assert.Nil(t, err)
  200. // fmt.Println("\n proofs:")
  201. // fmt.Println(proof)
  202. // fmt.Println("public signals:", proof.PublicSignals)
  203. fmt.Println("\n", circuit.Signals)
  204. fmt.Println("witness", w)
  205. b12Verif := big.NewInt(int64(12))
  206. publicSignalsVerif := []*big.Int{b12Verif}
  207. before := time.Now()
  208. assert.True(t, setup.Verify(*circuit, proof, publicSignalsVerif, true))
  209. fmt.Println("verify proof time elapsed:", time.Since(before))
  210. // check that with another public input the verification returns false
  211. bOtherWrongPublic := big.NewInt(int64(11))
  212. wrongPublicSignalsVerif := []*big.Int{bOtherWrongPublic}
  213. assert.True(t, !setup.Verify(*circuit, proof, wrongPublicSignalsVerif, false))
  214. }
  215. func TestMinimalFlow(t *testing.T) {
  216. // circuit function
  217. // y = x^3 + x + 5
  218. code := `
  219. func main(private s0, public s1):
  220. s2 = s0 * s0
  221. s3 = s2 * s0
  222. s4 = s3 + s0
  223. s5 = s4 + 5
  224. equals(s1, s5)
  225. out = 1 * 1
  226. `
  227. fmt.Print("\ncode of the circuit:")
  228. fmt.Println(code)
  229. // parse the code
  230. parser := circuit.NewParser(strings.NewReader(code))
  231. circuit, err := parser.Parse()
  232. assert.Nil(t, err)
  233. b3 := big.NewInt(int64(3))
  234. privateInputs := []*big.Int{b3}
  235. b35 := big.NewInt(int64(35))
  236. publicSignals := []*big.Int{b35}
  237. // wittness
  238. w, err := circuit.CalculateWitness(privateInputs, publicSignals)
  239. assert.Nil(t, err)
  240. // code to R1CS
  241. fmt.Println("\ngenerating R1CS from code")
  242. a, b, c := circuit.GenerateR1CS()
  243. fmt.Println("\nR1CS:")
  244. fmt.Println("a:", a)
  245. fmt.Println("b:", b)
  246. fmt.Println("c:", c)
  247. // R1CS to QAP
  248. // TODO zxQAP is not used and is an old impl, TODO remove
  249. alphas, betas, gammas, _ := Utils.PF.R1CSToQAP(a, b, c)
  250. fmt.Println("qap")
  251. assert.Equal(t, 8, len(alphas))
  252. assert.Equal(t, 8, len(alphas))
  253. assert.Equal(t, 8, len(alphas))
  254. assert.True(t, !bytes.Equal(alphas[1][1].Bytes(), big.NewInt(int64(0)).Bytes()))
  255. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  256. assert.Equal(t, 7, len(ax))
  257. assert.Equal(t, 7, len(bx))
  258. assert.Equal(t, 7, len(cx))
  259. assert.Equal(t, 13, len(px))
  260. // calculate trusted setup
  261. setup := &PinocchioSetup{}
  262. err = setup.Init(len(w), *circuit, alphas, betas, gammas)
  263. assert.Nil(t, err)
  264. fmt.Println("\nt:", setup.Toxic.T)
  265. hx := Utils.PF.DivisorPolynomial(px, setup.Pk.Z)
  266. div, rem := Utils.PF.Div(px, setup.Pk.Z)
  267. assert.Equal(t, hx, div)
  268. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(6))
  269. // hx==px/zx so px==hx*zx
  270. assert.Equal(t, px, Utils.PF.Mul(hx, setup.Pk.Z))
  271. // check length of polynomials H(x) and Z(x)
  272. assert.Equal(t, len(hx), len(px)-len(setup.Pk.Z)+1)
  273. proof, err := setup.Generate(*circuit, w, px)
  274. assert.Nil(t, err)
  275. // fmt.Println("\n proofs:")
  276. // fmt.Println(proof)
  277. // fmt.Println("public signals:", proof.PublicSignals)
  278. fmt.Println("\nsignals:", circuit.Signals)
  279. fmt.Println("witness:", w)
  280. b35Verif := big.NewInt(int64(35))
  281. publicSignalsVerif := []*big.Int{b35Verif}
  282. before := time.Now()
  283. assert.True(t, setup.Verify(*circuit, proof, publicSignalsVerif, true))
  284. fmt.Println("verify proof time elapsed:", time.Since(before))
  285. // check that with another public input the verification returns false
  286. bOtherWrongPublic := big.NewInt(int64(34))
  287. wrongPublicSignalsVerif := []*big.Int{bOtherWrongPublic}
  288. assert.True(t, !setup.Verify(*circuit, proof, wrongPublicSignalsVerif, false))
  289. }