You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

440 lines
13 KiB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
  1. package snark
  2. import (
  3. "bytes"
  4. "fmt"
  5. "math/big"
  6. "strings"
  7. "testing"
  8. "time"
  9. "github.com/arnaucube/go-snark/circuitcompiler"
  10. "github.com/arnaucube/go-snark/groth16"
  11. "github.com/arnaucube/go-snark/r1csqap"
  12. "github.com/stretchr/testify/assert"
  13. )
  14. func TestGroth16MinimalFlow(t *testing.T) {
  15. fmt.Println("testing Groth16 minimal flow")
  16. // circuit function
  17. // y = x^3 + x + 5
  18. code := `
  19. func main(private s0, public s1):
  20. s2 = s0 * s0
  21. s3 = s2 * s0
  22. s4 = s3 + s0
  23. s5 = s4 + 5
  24. equals(s1, s5)
  25. out = 1 * 1
  26. `
  27. fmt.Print("\ncode of the circuit:")
  28. // parse the code
  29. parser := circuitcompiler.NewParser(strings.NewReader(code))
  30. circuit, err := parser.Parse()
  31. assert.Nil(t, err)
  32. b3 := big.NewInt(int64(3))
  33. privateInputs := []*big.Int{b3}
  34. b35 := big.NewInt(int64(35))
  35. publicSignals := []*big.Int{b35}
  36. // wittness
  37. w, err := circuit.CalculateWitness(privateInputs, publicSignals)
  38. assert.Nil(t, err)
  39. // code to R1CS
  40. fmt.Println("\ngenerating R1CS from code")
  41. a, b, c := circuit.GenerateR1CS()
  42. fmt.Println("\nR1CS:")
  43. fmt.Println("a:", a)
  44. fmt.Println("b:", b)
  45. fmt.Println("c:", c)
  46. // R1CS to QAP
  47. // TODO zxQAP is not used and is an old impl, TODO remove
  48. alphas, betas, gammas, _ := Utils.PF.R1CSToQAP(a, b, c)
  49. fmt.Println("qap")
  50. assert.Equal(t, 8, len(alphas))
  51. assert.Equal(t, 8, len(alphas))
  52. assert.Equal(t, 8, len(alphas))
  53. assert.True(t, !bytes.Equal(alphas[1][1].Bytes(), big.NewInt(int64(0)).Bytes()))
  54. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  55. assert.Equal(t, 7, len(ax))
  56. assert.Equal(t, 7, len(bx))
  57. assert.Equal(t, 7, len(cx))
  58. assert.Equal(t, 13, len(px))
  59. // ---
  60. // from here is the GROTH16
  61. // ---
  62. // calculate trusted setup
  63. fmt.Println("groth")
  64. setup, err := groth16.GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas)
  65. assert.Nil(t, err)
  66. fmt.Println("\nt:", setup.Toxic.T)
  67. hx := Utils.PF.DivisorPolynomial(px, setup.Pk.Z)
  68. div, rem := Utils.PF.Div(px, setup.Pk.Z)
  69. assert.Equal(t, hx, div)
  70. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(6))
  71. // hx==px/zx so px==hx*zx
  72. assert.Equal(t, px, Utils.PF.Mul(hx, setup.Pk.Z))
  73. // check length of polynomials H(x) and Z(x)
  74. assert.Equal(t, len(hx), len(px)-len(setup.Pk.Z)+1)
  75. proof, err := groth16.GenerateProofs(*circuit, setup, w, px)
  76. assert.Nil(t, err)
  77. // fmt.Println("\n proofs:")
  78. // fmt.Println(proof)
  79. // fmt.Println("public signals:", proof.PublicSignals)
  80. fmt.Println("\nsignals:", circuit.Signals)
  81. fmt.Println("witness:", w)
  82. b35Verif := big.NewInt(int64(35))
  83. publicSignalsVerif := []*big.Int{b35Verif}
  84. before := time.Now()
  85. assert.True(t, groth16.VerifyProof(setup, proof, publicSignalsVerif, true))
  86. fmt.Println("verify proof time elapsed:", time.Since(before))
  87. // check that with another public input the verification returns false
  88. bOtherWrongPublic := big.NewInt(int64(34))
  89. wrongPublicSignalsVerif := []*big.Int{bOtherWrongPublic}
  90. assert.True(t, !groth16.VerifyProof(setup, proof, wrongPublicSignalsVerif, false))
  91. }
  92. func TestZkFromFlatCircuitCode(t *testing.T) {
  93. // compile circuit and get the R1CS
  94. // circuit function
  95. // y = x^3 + x + 5
  96. code := `
  97. func exp3(private a):
  98. b = a * a
  99. c = a * b
  100. return c
  101. func sum(private a, private b):
  102. c = a + b
  103. return c
  104. func main(private s0, public s1):
  105. s3 = exp3(s0)
  106. s4 = sum(s3, s0)
  107. s5 = s4 + 5
  108. equals(s1, s5)
  109. out = 1 * 1
  110. `
  111. // the same code without the functions calling, all in one func
  112. // code := `
  113. // func test(private s0, public s1):
  114. // s2 = s0 * s0
  115. // s3 = s2 * s0
  116. // s4 = s3 + s0
  117. // s5 = s4 + 5
  118. // equals(s1, s5)
  119. // out = 1 * 1
  120. // `
  121. fmt.Print("\ncode of the circuit:")
  122. fmt.Println(code)
  123. // parse the code
  124. parser := circuitcompiler.NewParser(strings.NewReader(code))
  125. circuit, err := parser.Parse()
  126. assert.Nil(t, err)
  127. // fmt.Println("\ncircuit data:", circuit)
  128. // circuitJson, _ := json.Marshal(circuit)
  129. // fmt.Println("circuit:", string(circuitJson))
  130. b3 := big.NewInt(int64(3))
  131. privateInputs := []*big.Int{b3}
  132. b35 := big.NewInt(int64(35))
  133. publicSignals := []*big.Int{b35}
  134. // wittness
  135. w, err := circuit.CalculateWitness(privateInputs, publicSignals)
  136. assert.Nil(t, err)
  137. // code to R1CS
  138. fmt.Println("\ngenerating R1CS from code")
  139. a, b, c := circuit.GenerateR1CS()
  140. fmt.Println("\nR1CS:")
  141. fmt.Println("a:", a)
  142. fmt.Println("b:", b)
  143. fmt.Println("c:", c)
  144. // R1CS to QAP
  145. // TODO zxQAP is not used and is an old impl, TODO remove
  146. alphas, betas, gammas, zxQAP := Utils.PF.R1CSToQAP(a, b, c)
  147. fmt.Println("qap")
  148. assert.Equal(t, 8, len(alphas))
  149. assert.Equal(t, 8, len(alphas))
  150. assert.Equal(t, 8, len(alphas))
  151. assert.Equal(t, 7, len(zxQAP))
  152. assert.True(t, !bytes.Equal(alphas[1][1].Bytes(), big.NewInt(int64(0)).Bytes()))
  153. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  154. assert.Equal(t, 7, len(ax))
  155. assert.Equal(t, 7, len(bx))
  156. assert.Equal(t, 7, len(cx))
  157. assert.Equal(t, 13, len(px))
  158. hxQAP := Utils.PF.DivisorPolynomial(px, zxQAP)
  159. assert.Equal(t, 7, len(hxQAP))
  160. // hx==px/zx so px==hx*zx
  161. assert.Equal(t, px, Utils.PF.Mul(hxQAP, zxQAP))
  162. // p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
  163. abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx)
  164. assert.Equal(t, abc, px)
  165. hzQAP := Utils.PF.Mul(hxQAP, zxQAP)
  166. assert.Equal(t, abc, hzQAP)
  167. div, rem := Utils.PF.Div(px, zxQAP)
  168. assert.Equal(t, hxQAP, div)
  169. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(6))
  170. // calculate trusted setup
  171. setup, err := GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas)
  172. assert.Nil(t, err)
  173. fmt.Println("\nt:", setup.Toxic.T)
  174. // zx and setup.Pk.Z should be the same (currently not, the correct one is the calculation used inside GenerateTrustedSetup function), the calculation is repeated. TODO avoid repeating calculation
  175. assert.Equal(t, zxQAP, setup.Pk.Z)
  176. hx := Utils.PF.DivisorPolynomial(px, setup.Pk.Z)
  177. assert.Equal(t, hx, hxQAP)
  178. // assert.Equal(t, hxQAP, hx)
  179. div, rem = Utils.PF.Div(px, setup.Pk.Z)
  180. assert.Equal(t, hx, div)
  181. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(6))
  182. assert.Equal(t, px, Utils.PF.Mul(hxQAP, zxQAP))
  183. // hx==px/zx so px==hx*zx
  184. assert.Equal(t, px, Utils.PF.Mul(hx, setup.Pk.Z))
  185. // check length of polynomials H(x) and Z(x)
  186. assert.Equal(t, len(hx), len(px)-len(setup.Pk.Z)+1)
  187. assert.Equal(t, len(hxQAP), len(px)-len(zxQAP)+1)
  188. proof, err := GenerateProofs(*circuit, setup, w, px)
  189. assert.Nil(t, err)
  190. // fmt.Println("\n proofs:")
  191. // fmt.Println(proof)
  192. // fmt.Println("public signals:", proof.PublicSignals)
  193. fmt.Println("\nsignals:", circuit.Signals)
  194. fmt.Println("witness:", w)
  195. b35Verif := big.NewInt(int64(35))
  196. publicSignalsVerif := []*big.Int{b35Verif}
  197. before := time.Now()
  198. assert.True(t, VerifyProof(setup, proof, publicSignalsVerif, true))
  199. fmt.Println("verify proof time elapsed:", time.Since(before))
  200. // check that with another public input the verification returns false
  201. bOtherWrongPublic := big.NewInt(int64(34))
  202. wrongPublicSignalsVerif := []*big.Int{bOtherWrongPublic}
  203. assert.True(t, !VerifyProof(setup, proof, wrongPublicSignalsVerif, false))
  204. }
  205. func TestZkMultiplication(t *testing.T) {
  206. code := `
  207. func main(private a, private b, public c):
  208. d = a * b
  209. equals(c, d)
  210. out = 1 * 1
  211. `
  212. fmt.Println("code", code)
  213. // parse the code
  214. parser := circuitcompiler.NewParser(strings.NewReader(code))
  215. circuit, err := parser.Parse()
  216. assert.Nil(t, err)
  217. b3 := big.NewInt(int64(3))
  218. b4 := big.NewInt(int64(4))
  219. privateInputs := []*big.Int{b3, b4}
  220. b12 := big.NewInt(int64(12))
  221. publicSignals := []*big.Int{b12}
  222. // wittness
  223. w, err := circuit.CalculateWitness(privateInputs, publicSignals)
  224. assert.Nil(t, err)
  225. // code to R1CS
  226. fmt.Println("\ngenerating R1CS from code")
  227. a, b, c := circuit.GenerateR1CS()
  228. fmt.Println("\nR1CS:")
  229. fmt.Println("a:", a)
  230. fmt.Println("b:", b)
  231. fmt.Println("c:", c)
  232. // R1CS to QAP
  233. // TODO zxQAP is not used and is an old impl. TODO remove
  234. alphas, betas, gammas, zxQAP := Utils.PF.R1CSToQAP(a, b, c)
  235. assert.Equal(t, 6, len(alphas))
  236. assert.Equal(t, 6, len(betas))
  237. assert.Equal(t, 6, len(betas))
  238. assert.Equal(t, 5, len(zxQAP))
  239. assert.True(t, !bytes.Equal(alphas[1][1].Bytes(), big.NewInt(int64(0)).Bytes()))
  240. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  241. assert.Equal(t, 4, len(ax))
  242. assert.Equal(t, 4, len(bx))
  243. assert.Equal(t, 4, len(cx))
  244. assert.Equal(t, 7, len(px))
  245. hxQAP := Utils.PF.DivisorPolynomial(px, zxQAP)
  246. assert.Equal(t, 3, len(hxQAP))
  247. // hx==px/zx so px==hx*zx
  248. assert.Equal(t, px, Utils.PF.Mul(hxQAP, zxQAP))
  249. // p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
  250. abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx)
  251. assert.Equal(t, abc, px)
  252. hzQAP := Utils.PF.Mul(hxQAP, zxQAP)
  253. assert.Equal(t, abc, hzQAP)
  254. div, rem := Utils.PF.Div(px, zxQAP)
  255. assert.Equal(t, hxQAP, div)
  256. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4))
  257. // calculate trusted setup
  258. setup, err := GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas)
  259. assert.Nil(t, err)
  260. // fmt.Println("\nt:", setup.Toxic.T)
  261. // zx and setup.Pk.Z should be the same (currently not, the correct one is the calculation used inside GenerateTrustedSetup function), the calculation is repeated. TODO avoid repeating calculation
  262. assert.Equal(t, zxQAP, setup.Pk.Z)
  263. hx := Utils.PF.DivisorPolynomial(px, setup.Pk.Z)
  264. assert.Equal(t, 3, len(hx))
  265. assert.Equal(t, hx, hxQAP)
  266. div, rem = Utils.PF.Div(px, setup.Pk.Z)
  267. assert.Equal(t, hx, div)
  268. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4))
  269. assert.Equal(t, px, Utils.PF.Mul(hxQAP, zxQAP))
  270. // hx==px/zx so px==hx*zx
  271. assert.Equal(t, px, Utils.PF.Mul(hx, setup.Pk.Z))
  272. // check length of polynomials H(x) and Z(x)
  273. assert.Equal(t, len(hx), len(px)-len(setup.Pk.Z)+1)
  274. assert.Equal(t, len(hxQAP), len(px)-len(zxQAP)+1)
  275. proof, err := GenerateProofs(*circuit, setup, w, px)
  276. assert.Nil(t, err)
  277. // fmt.Println("\n proofs:")
  278. // fmt.Println(proof)
  279. // fmt.Println("public signals:", proof.PublicSignals)
  280. fmt.Println("\n", circuit.Signals)
  281. fmt.Println("witness", w)
  282. b12Verif := big.NewInt(int64(12))
  283. publicSignalsVerif := []*big.Int{b12Verif}
  284. before := time.Now()
  285. assert.True(t, VerifyProof(setup, proof, publicSignalsVerif, true))
  286. fmt.Println("verify proof time elapsed:", time.Since(before))
  287. // check that with another public input the verification returns false
  288. bOtherWrongPublic := big.NewInt(int64(11))
  289. wrongPublicSignalsVerif := []*big.Int{bOtherWrongPublic}
  290. assert.True(t, !VerifyProof(setup, proof, wrongPublicSignalsVerif, false))
  291. }
  292. func TestMinimalFlow(t *testing.T) {
  293. // circuit function
  294. // y = x^3 + x + 5
  295. code := `
  296. func main(private s0, public s1):
  297. s2 = s0 * s0
  298. s3 = s2 * s0
  299. s4 = s3 + s0
  300. s5 = s4 + 5
  301. equals(s1, s5)
  302. out = 1 * 1
  303. `
  304. fmt.Print("\ncode of the circuit:")
  305. fmt.Println(code)
  306. // parse the code
  307. parser := circuitcompiler.NewParser(strings.NewReader(code))
  308. circuit, err := parser.Parse()
  309. assert.Nil(t, err)
  310. b3 := big.NewInt(int64(3))
  311. privateInputs := []*big.Int{b3}
  312. b35 := big.NewInt(int64(35))
  313. publicSignals := []*big.Int{b35}
  314. // wittness
  315. w, err := circuit.CalculateWitness(privateInputs, publicSignals)
  316. assert.Nil(t, err)
  317. // code to R1CS
  318. fmt.Println("\ngenerating R1CS from code")
  319. a, b, c := circuit.GenerateR1CS()
  320. fmt.Println("\nR1CS:")
  321. fmt.Println("a:", a)
  322. fmt.Println("b:", b)
  323. fmt.Println("c:", c)
  324. // R1CS to QAP
  325. // TODO zxQAP is not used and is an old impl, TODO remove
  326. alphas, betas, gammas, _ := Utils.PF.R1CSToQAP(a, b, c)
  327. fmt.Println("qap")
  328. assert.Equal(t, 8, len(alphas))
  329. assert.Equal(t, 8, len(alphas))
  330. assert.Equal(t, 8, len(alphas))
  331. assert.True(t, !bytes.Equal(alphas[1][1].Bytes(), big.NewInt(int64(0)).Bytes()))
  332. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  333. assert.Equal(t, 7, len(ax))
  334. assert.Equal(t, 7, len(bx))
  335. assert.Equal(t, 7, len(cx))
  336. assert.Equal(t, 13, len(px))
  337. // calculate trusted setup
  338. setup, err := GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas)
  339. assert.Nil(t, err)
  340. fmt.Println("\nt:", setup.Toxic.T)
  341. hx := Utils.PF.DivisorPolynomial(px, setup.Pk.Z)
  342. div, rem := Utils.PF.Div(px, setup.Pk.Z)
  343. assert.Equal(t, hx, div)
  344. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(6))
  345. // hx==px/zx so px==hx*zx
  346. assert.Equal(t, px, Utils.PF.Mul(hx, setup.Pk.Z))
  347. // check length of polynomials H(x) and Z(x)
  348. assert.Equal(t, len(hx), len(px)-len(setup.Pk.Z)+1)
  349. proof, err := GenerateProofs(*circuit, setup, w, px)
  350. assert.Nil(t, err)
  351. // fmt.Println("\n proofs:")
  352. // fmt.Println(proof)
  353. // fmt.Println("public signals:", proof.PublicSignals)
  354. fmt.Println("\nsignals:", circuit.Signals)
  355. fmt.Println("witness:", w)
  356. b35Verif := big.NewInt(int64(35))
  357. publicSignalsVerif := []*big.Int{b35Verif}
  358. before := time.Now()
  359. assert.True(t, VerifyProof(setup, proof, publicSignalsVerif, true))
  360. fmt.Println("verify proof time elapsed:", time.Since(before))
  361. // check that with another public input the verification returns false
  362. bOtherWrongPublic := big.NewInt(int64(34))
  363. wrongPublicSignalsVerif := []*big.Int{bOtherWrongPublic}
  364. assert.True(t, !VerifyProof(setup, proof, wrongPublicSignalsVerif, false))
  365. }