You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

346 lines
10 KiB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
  1. package snark
  2. import (
  3. "bytes"
  4. "fmt"
  5. "math/big"
  6. "strings"
  7. "testing"
  8. "time"
  9. "github.com/arnaucube/go-snark/circuitcompiler"
  10. "github.com/arnaucube/go-snark/r1csqap"
  11. "github.com/stretchr/testify/assert"
  12. )
  13. func TestZkFromFlatCircuitCode(t *testing.T) {
  14. // compile circuit and get the R1CS
  15. // circuit function
  16. // y = x^3 + x + 5
  17. code := `
  18. func exp3(private a):
  19. b = a * a
  20. c = a * b
  21. return c
  22. func sum(private a, private b):
  23. c = a + b
  24. return c
  25. func main(private s0, public s1):
  26. s3 = exp3(s0)
  27. s4 = sum(s3, s0)
  28. s5 = s4 + 5
  29. equals(s1, s5)
  30. out = 1 * 1
  31. `
  32. // the same code without the functions calling, all in one func
  33. // code := `
  34. // func test(private s0, public s1):
  35. // s2 = s0 * s0
  36. // s3 = s2 * s0
  37. // s4 = s3 + s0
  38. // s5 = s4 + 5
  39. // equals(s1, s5)
  40. // out = 1 * 1
  41. // `
  42. fmt.Print("\ncode of the circuit:")
  43. fmt.Println(code)
  44. // parse the code
  45. parser := circuitcompiler.NewParser(strings.NewReader(code))
  46. circuit, err := parser.Parse()
  47. assert.Nil(t, err)
  48. // fmt.Println("\ncircuit data:", circuit)
  49. // circuitJson, _ := json.Marshal(circuit)
  50. // fmt.Println("circuit:", string(circuitJson))
  51. b3 := big.NewInt(int64(3))
  52. privateInputs := []*big.Int{b3}
  53. b35 := big.NewInt(int64(35))
  54. publicSignals := []*big.Int{b35}
  55. // wittness
  56. w, err := circuit.CalculateWitness(privateInputs, publicSignals)
  57. assert.Nil(t, err)
  58. // code to R1CS
  59. fmt.Println("\ngenerating R1CS from code")
  60. a, b, c := circuit.GenerateR1CS()
  61. fmt.Println("\nR1CS:")
  62. fmt.Println("a:", a)
  63. fmt.Println("b:", b)
  64. fmt.Println("c:", c)
  65. // R1CS to QAP
  66. // TODO zxQAP is not used and is an old impl, TODO remove
  67. alphas, betas, gammas, zxQAP := Utils.PF.R1CSToQAP(a, b, c)
  68. fmt.Println("qap")
  69. assert.Equal(t, 8, len(alphas))
  70. assert.Equal(t, 8, len(alphas))
  71. assert.Equal(t, 8, len(alphas))
  72. assert.Equal(t, 7, len(zxQAP))
  73. assert.True(t, !bytes.Equal(alphas[1][1].Bytes(), big.NewInt(int64(0)).Bytes()))
  74. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  75. assert.Equal(t, 7, len(ax))
  76. assert.Equal(t, 7, len(bx))
  77. assert.Equal(t, 7, len(cx))
  78. assert.Equal(t, 13, len(px))
  79. hxQAP := Utils.PF.DivisorPolynomial(px, zxQAP)
  80. assert.Equal(t, 7, len(hxQAP))
  81. // hx==px/zx so px==hx*zx
  82. assert.Equal(t, px, Utils.PF.Mul(hxQAP, zxQAP))
  83. // p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
  84. abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx)
  85. assert.Equal(t, abc, px)
  86. hzQAP := Utils.PF.Mul(hxQAP, zxQAP)
  87. assert.Equal(t, abc, hzQAP)
  88. div, rem := Utils.PF.Div(px, zxQAP)
  89. assert.Equal(t, hxQAP, div)
  90. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(6))
  91. // calculate trusted setup
  92. setup, err := GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas)
  93. assert.Nil(t, err)
  94. fmt.Println("\nt:", setup.Toxic.T)
  95. // zx and setup.Pk.Z should be the same (currently not, the correct one is the calculation used inside GenerateTrustedSetup function), the calculation is repeated. TODO avoid repeating calculation
  96. assert.Equal(t, zxQAP, setup.Pk.Z)
  97. hx := Utils.PF.DivisorPolynomial(px, setup.Pk.Z)
  98. assert.Equal(t, hx, hxQAP)
  99. // assert.Equal(t, hxQAP, hx)
  100. div, rem = Utils.PF.Div(px, setup.Pk.Z)
  101. assert.Equal(t, hx, div)
  102. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(6))
  103. assert.Equal(t, px, Utils.PF.Mul(hxQAP, zxQAP))
  104. // hx==px/zx so px==hx*zx
  105. assert.Equal(t, px, Utils.PF.Mul(hx, setup.Pk.Z))
  106. // check length of polynomials H(x) and Z(x)
  107. assert.Equal(t, len(hx), len(px)-len(setup.Pk.Z)+1)
  108. assert.Equal(t, len(hxQAP), len(px)-len(zxQAP)+1)
  109. proof, err := GenerateProofs(*circuit, setup, w, px)
  110. assert.Nil(t, err)
  111. // fmt.Println("\n proofs:")
  112. // fmt.Println(proof)
  113. // fmt.Println("public signals:", proof.PublicSignals)
  114. fmt.Println("\nsignals:", circuit.Signals)
  115. fmt.Println("witness:", w)
  116. b35Verif := big.NewInt(int64(35))
  117. publicSignalsVerif := []*big.Int{b35Verif}
  118. before := time.Now()
  119. assert.True(t, VerifyProof(*circuit, setup, proof, publicSignalsVerif, true))
  120. fmt.Println("verify proof time elapsed:", time.Since(before))
  121. // check that with another public input the verification returns false
  122. bOtherWrongPublic := big.NewInt(int64(34))
  123. wrongPublicSignalsVerif := []*big.Int{bOtherWrongPublic}
  124. assert.True(t, !VerifyProof(*circuit, setup, proof, wrongPublicSignalsVerif, true))
  125. }
  126. func TestZkMultiplication(t *testing.T) {
  127. code := `
  128. func main(private a, private b, public c):
  129. d = a * b
  130. equals(c, d)
  131. out = 1 * 1
  132. `
  133. fmt.Println("code", code)
  134. // parse the code
  135. parser := circuitcompiler.NewParser(strings.NewReader(code))
  136. circuit, err := parser.Parse()
  137. assert.Nil(t, err)
  138. b3 := big.NewInt(int64(3))
  139. b4 := big.NewInt(int64(4))
  140. privateInputs := []*big.Int{b3, b4}
  141. b12 := big.NewInt(int64(12))
  142. publicSignals := []*big.Int{b12}
  143. // wittness
  144. w, err := circuit.CalculateWitness(privateInputs, publicSignals)
  145. assert.Nil(t, err)
  146. // code to R1CS
  147. fmt.Println("\ngenerating R1CS from code")
  148. a, b, c := circuit.GenerateR1CS()
  149. fmt.Println("\nR1CS:")
  150. fmt.Println("a:", a)
  151. fmt.Println("b:", b)
  152. fmt.Println("c:", c)
  153. // R1CS to QAP
  154. // TODO zxQAP is not used and is an old impl. TODO remove
  155. alphas, betas, gammas, zxQAP := Utils.PF.R1CSToQAP(a, b, c)
  156. assert.Equal(t, 6, len(alphas))
  157. assert.Equal(t, 6, len(betas))
  158. assert.Equal(t, 6, len(betas))
  159. assert.Equal(t, 5, len(zxQAP))
  160. assert.True(t, !bytes.Equal(alphas[1][1].Bytes(), big.NewInt(int64(0)).Bytes()))
  161. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  162. assert.Equal(t, 4, len(ax))
  163. assert.Equal(t, 4, len(bx))
  164. assert.Equal(t, 4, len(cx))
  165. assert.Equal(t, 7, len(px))
  166. hxQAP := Utils.PF.DivisorPolynomial(px, zxQAP)
  167. assert.Equal(t, 3, len(hxQAP))
  168. // hx==px/zx so px==hx*zx
  169. assert.Equal(t, px, Utils.PF.Mul(hxQAP, zxQAP))
  170. // p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
  171. abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx)
  172. assert.Equal(t, abc, px)
  173. hzQAP := Utils.PF.Mul(hxQAP, zxQAP)
  174. assert.Equal(t, abc, hzQAP)
  175. div, rem := Utils.PF.Div(px, zxQAP)
  176. assert.Equal(t, hxQAP, div)
  177. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4))
  178. // calculate trusted setup
  179. setup, err := GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas)
  180. assert.Nil(t, err)
  181. // fmt.Println("\nt:", setup.Toxic.T)
  182. // zx and setup.Pk.Z should be the same (currently not, the correct one is the calculation used inside GenerateTrustedSetup function), the calculation is repeated. TODO avoid repeating calculation
  183. assert.Equal(t, zxQAP, setup.Pk.Z)
  184. hx := Utils.PF.DivisorPolynomial(px, setup.Pk.Z)
  185. assert.Equal(t, 3, len(hx))
  186. assert.Equal(t, hx, hxQAP)
  187. div, rem = Utils.PF.Div(px, setup.Pk.Z)
  188. assert.Equal(t, hx, div)
  189. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4))
  190. assert.Equal(t, px, Utils.PF.Mul(hxQAP, zxQAP))
  191. // hx==px/zx so px==hx*zx
  192. assert.Equal(t, px, Utils.PF.Mul(hx, setup.Pk.Z))
  193. // check length of polynomials H(x) and Z(x)
  194. assert.Equal(t, len(hx), len(px)-len(setup.Pk.Z)+1)
  195. assert.Equal(t, len(hxQAP), len(px)-len(zxQAP)+1)
  196. proof, err := GenerateProofs(*circuit, setup, w, px)
  197. assert.Nil(t, err)
  198. // fmt.Println("\n proofs:")
  199. // fmt.Println(proof)
  200. // fmt.Println("public signals:", proof.PublicSignals)
  201. fmt.Println("\n", circuit.Signals)
  202. fmt.Println("witness", w)
  203. b12Verif := big.NewInt(int64(12))
  204. publicSignalsVerif := []*big.Int{b12Verif}
  205. before := time.Now()
  206. assert.True(t, VerifyProof(*circuit, setup, proof, publicSignalsVerif, true))
  207. fmt.Println("verify proof time elapsed:", time.Since(before))
  208. // check that with another public input the verification returns false
  209. bOtherWrongPublic := big.NewInt(int64(11))
  210. wrongPublicSignalsVerif := []*big.Int{bOtherWrongPublic}
  211. assert.True(t, !VerifyProof(*circuit, setup, proof, wrongPublicSignalsVerif, true))
  212. }
  213. func TestMinimalFlow(t *testing.T) {
  214. // circuit function
  215. // y = x^3 + x + 5
  216. code := `
  217. func main(private s0, public s1):
  218. s2 = s0 * s0
  219. s3 = s2 * s0
  220. s4 = s3 + s0
  221. s5 = s4 + 5
  222. equals(s1, s5)
  223. out = 1 * 1
  224. `
  225. fmt.Print("\ncode of the circuit:")
  226. fmt.Println(code)
  227. // parse the code
  228. parser := circuitcompiler.NewParser(strings.NewReader(code))
  229. circuit, err := parser.Parse()
  230. assert.Nil(t, err)
  231. b3 := big.NewInt(int64(3))
  232. privateInputs := []*big.Int{b3}
  233. b35 := big.NewInt(int64(35))
  234. publicSignals := []*big.Int{b35}
  235. // wittness
  236. w, err := circuit.CalculateWitness(privateInputs, publicSignals)
  237. assert.Nil(t, err)
  238. // code to R1CS
  239. fmt.Println("\ngenerating R1CS from code")
  240. a, b, c := circuit.GenerateR1CS()
  241. fmt.Println("\nR1CS:")
  242. fmt.Println("a:", a)
  243. fmt.Println("b:", b)
  244. fmt.Println("c:", c)
  245. // R1CS to QAP
  246. // TODO zxQAP is not used and is an old impl, TODO remove
  247. alphas, betas, gammas, _ := Utils.PF.R1CSToQAP(a, b, c)
  248. fmt.Println("qap")
  249. assert.Equal(t, 8, len(alphas))
  250. assert.Equal(t, 8, len(alphas))
  251. assert.Equal(t, 8, len(alphas))
  252. assert.True(t, !bytes.Equal(alphas[1][1].Bytes(), big.NewInt(int64(0)).Bytes()))
  253. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  254. assert.Equal(t, 7, len(ax))
  255. assert.Equal(t, 7, len(bx))
  256. assert.Equal(t, 7, len(cx))
  257. assert.Equal(t, 13, len(px))
  258. // calculate trusted setup
  259. setup, err := GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas)
  260. assert.Nil(t, err)
  261. fmt.Println("\nt:", setup.Toxic.T)
  262. hx := Utils.PF.DivisorPolynomial(px, setup.Pk.Z)
  263. div, rem := Utils.PF.Div(px, setup.Pk.Z)
  264. assert.Equal(t, hx, div)
  265. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(6))
  266. // hx==px/zx so px==hx*zx
  267. assert.Equal(t, px, Utils.PF.Mul(hx, setup.Pk.Z))
  268. // check length of polynomials H(x) and Z(x)
  269. assert.Equal(t, len(hx), len(px)-len(setup.Pk.Z)+1)
  270. proof, err := GenerateProofs(*circuit, setup, w, px)
  271. assert.Nil(t, err)
  272. // fmt.Println("\n proofs:")
  273. // fmt.Println(proof)
  274. // fmt.Println("public signals:", proof.PublicSignals)
  275. fmt.Println("\nsignals:", circuit.Signals)
  276. fmt.Println("witness:", w)
  277. b35Verif := big.NewInt(int64(35))
  278. publicSignalsVerif := []*big.Int{b35Verif}
  279. before := time.Now()
  280. assert.True(t, VerifyProof(*circuit, setup, proof, publicSignalsVerif, true))
  281. fmt.Println("verify proof time elapsed:", time.Since(before))
  282. // check that with another public input the verification returns false
  283. bOtherWrongPublic := big.NewInt(int64(34))
  284. wrongPublicSignalsVerif := []*big.Int{bOtherWrongPublic}
  285. assert.True(t, !VerifyProof(*circuit, setup, proof, wrongPublicSignalsVerif, true))
  286. }