You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

386 lines
11 KiB

5 years ago
5 years ago
  1. package snark
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "math/big"
  7. "strings"
  8. "testing"
  9. "time"
  10. "github.com/arnaucube/go-snark/circuitcompiler"
  11. "github.com/arnaucube/go-snark/r1csqap"
  12. "github.com/stretchr/testify/assert"
  13. )
  14. func TestZkFromFlatCircuitCode(t *testing.T) {
  15. // compile circuit and get the R1CS
  16. // circuit function
  17. // y = x^3 + x + 5
  18. flatCode := `
  19. func test(private s0, public s1):
  20. s2 = s0 * s0
  21. s3 = s2 * s0
  22. s4 = s3 + s0
  23. s5 = s4 + 5
  24. equals(s1, s5)
  25. out = 1 * 1
  26. `
  27. fmt.Print("\nflat code of the circuit:")
  28. fmt.Println(flatCode)
  29. // parse the code
  30. parser := circuitcompiler.NewParser(strings.NewReader(flatCode))
  31. circuit, err := parser.Parse()
  32. assert.Nil(t, err)
  33. fmt.Println("\ncircuit data:", circuit)
  34. circuitJson, _ := json.Marshal(circuit)
  35. fmt.Println("circuit:", string(circuitJson))
  36. b3 := big.NewInt(int64(3))
  37. privateInputs := []*big.Int{b3}
  38. b35 := big.NewInt(int64(35))
  39. publicSignals := []*big.Int{b35}
  40. // wittness
  41. w, err := circuit.CalculateWitness(privateInputs, publicSignals)
  42. assert.Nil(t, err)
  43. fmt.Println("\n", circuit.Signals)
  44. fmt.Println("witness", w)
  45. // flat code to R1CS
  46. fmt.Println("\ngenerating R1CS from flat code")
  47. a, b, c := circuit.GenerateR1CS()
  48. fmt.Println("\nR1CS:")
  49. fmt.Println("a:", a)
  50. fmt.Println("b:", b)
  51. fmt.Println("c:", c)
  52. // R1CS to QAP
  53. // TODO zxQAP is not used and is an old impl, bad calculated. TODO remove
  54. alphas, betas, gammas, zxQAP := Utils.PF.R1CSToQAP(a, b, c)
  55. fmt.Println("qap")
  56. fmt.Println("alphas", len(alphas))
  57. fmt.Println("alphas[1]", alphas[1])
  58. fmt.Println("betas", len(betas))
  59. fmt.Println("gammas", len(gammas))
  60. fmt.Println("zx length", len(zxQAP))
  61. assert.True(t, !bytes.Equal(alphas[1][1].Bytes(), big.NewInt(int64(0)).Bytes()))
  62. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  63. fmt.Println("ax length", len(ax))
  64. fmt.Println("bx length", len(bx))
  65. fmt.Println("cx length", len(cx))
  66. fmt.Println("px length", len(px))
  67. fmt.Println("px[last]", px[0])
  68. hxQAP := Utils.PF.DivisorPolynomial(px, zxQAP)
  69. fmt.Println("hx length", len(hxQAP))
  70. // hx==px/zx so px==hx*zx
  71. assert.Equal(t, px, Utils.PF.Mul(hxQAP, zxQAP))
  72. // p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
  73. abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx)
  74. assert.Equal(t, abc, px)
  75. hzQAP := Utils.PF.Mul(hxQAP, zxQAP)
  76. assert.Equal(t, abc, hzQAP)
  77. div, rem := Utils.PF.Div(px, zxQAP)
  78. assert.Equal(t, hxQAP, div)
  79. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(6))
  80. // calculate trusted setup
  81. setup, err := GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas)
  82. assert.Nil(t, err)
  83. fmt.Println("\nt:", setup.Toxic.T)
  84. // zx and setup.Pk.Z should be the same (currently not, the correct one is the calculation used inside GenerateTrustedSetup function), the calculation is repeated. TODO avoid repeating calculation
  85. // assert.Equal(t, zxQAP, setup.Pk.Z)
  86. fmt.Println("hx pk.z", hxQAP)
  87. hx := Utils.PF.DivisorPolynomial(px, setup.Pk.Z)
  88. fmt.Println("hx pk.z", hx)
  89. // assert.Equal(t, hxQAP, hx)
  90. div, rem = Utils.PF.Div(px, setup.Pk.Z)
  91. assert.Equal(t, hx, div)
  92. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(6))
  93. assert.Equal(t, px, Utils.PF.Mul(hxQAP, zxQAP))
  94. // hx==px/zx so px==hx*zx
  95. assert.Equal(t, px, Utils.PF.Mul(hx, setup.Pk.Z))
  96. // check length of polynomials H(x) and Z(x)
  97. assert.Equal(t, len(hx), len(px)-len(setup.Pk.Z)+1)
  98. assert.Equal(t, len(hxQAP), len(px)-len(zxQAP)+1)
  99. // fmt.Println("pk.Z", len(setup.Pk.Z))
  100. // fmt.Println("zxQAP", len(zxQAP))
  101. proof, err := GenerateProofs(*circuit, setup, w, px)
  102. assert.Nil(t, err)
  103. // fmt.Println("\n proofs:")
  104. // fmt.Println(proof)
  105. // fmt.Println("public signals:", proof.PublicSignals)
  106. fmt.Println("\n", circuit.Signals)
  107. fmt.Println("\nwitness", w)
  108. b35Verif := big.NewInt(int64(35))
  109. publicSignalsVerif := []*big.Int{b35Verif}
  110. before := time.Now()
  111. assert.True(t, VerifyProof(*circuit, setup, proof, publicSignalsVerif, true))
  112. fmt.Println("verify proof time elapsed:", time.Since(before))
  113. // check that with another public input the verification returns false
  114. bOtherWrongPublic := big.NewInt(int64(34))
  115. wrongPublicSignalsVerif := []*big.Int{bOtherWrongPublic}
  116. assert.True(t, !VerifyProof(*circuit, setup, proof, wrongPublicSignalsVerif, true))
  117. }
  118. func TestZkMultiplication(t *testing.T) {
  119. flatCode := `
  120. func test(private a, private b, public c):
  121. d = a * b
  122. equals(c, d)
  123. out = 1 * 1
  124. `
  125. fmt.Print("\nflat code of the circuit:")
  126. fmt.Println(flatCode)
  127. // parse the code
  128. parser := circuitcompiler.NewParser(strings.NewReader(flatCode))
  129. circuit, err := parser.Parse()
  130. assert.Nil(t, err)
  131. fmt.Println("\ncircuit data:", circuit)
  132. circuitJson, _ := json.Marshal(circuit)
  133. fmt.Println("circuit:", string(circuitJson))
  134. b3 := big.NewInt(int64(3))
  135. b4 := big.NewInt(int64(4))
  136. privateInputs := []*big.Int{b3, b4}
  137. b12 := big.NewInt(int64(12))
  138. publicSignals := []*big.Int{b12}
  139. // wittness
  140. w, err := circuit.CalculateWitness(privateInputs, publicSignals)
  141. assert.Nil(t, err)
  142. fmt.Println("\n", circuit.Signals)
  143. fmt.Println("witness", w)
  144. // flat code to R1CS
  145. fmt.Println("\ngenerating R1CS from flat code")
  146. a, b, c := circuit.GenerateR1CS()
  147. fmt.Println("\nR1CS:")
  148. fmt.Println("a:", a)
  149. fmt.Println("b:", b)
  150. fmt.Println("c:", c)
  151. // R1CS to QAP
  152. // TODO zxQAP is not used and is an old impl, bad calculated. TODO remove
  153. alphas, betas, gammas, zxQAP := Utils.PF.R1CSToQAP(a, b, c)
  154. fmt.Println("qap")
  155. fmt.Println("alphas", len(alphas))
  156. fmt.Println("alphas[1]", alphas[1])
  157. fmt.Println("betas", len(betas))
  158. fmt.Println("gammas", len(gammas))
  159. fmt.Println("zx length", len(zxQAP))
  160. assert.True(t, !bytes.Equal(alphas[1][1].Bytes(), big.NewInt(int64(0)).Bytes()))
  161. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  162. fmt.Println("ax length", len(ax))
  163. fmt.Println("bx length", len(bx))
  164. fmt.Println("cx length", len(cx))
  165. fmt.Println("px length", len(px))
  166. fmt.Println("px[last]", px[0])
  167. hxQAP := Utils.PF.DivisorPolynomial(px, zxQAP)
  168. fmt.Println("hx length", len(hxQAP))
  169. // hx==px/zx so px==hx*zx
  170. assert.Equal(t, px, Utils.PF.Mul(hxQAP, zxQAP))
  171. // p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
  172. abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx)
  173. assert.Equal(t, abc, px)
  174. hzQAP := Utils.PF.Mul(hxQAP, zxQAP)
  175. assert.Equal(t, abc, hzQAP)
  176. div, rem := Utils.PF.Div(px, zxQAP)
  177. assert.Equal(t, hxQAP, div)
  178. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4))
  179. // calculate trusted setup
  180. setup, err := GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas)
  181. assert.Nil(t, err)
  182. fmt.Println("\nt:", setup.Toxic.T)
  183. // zx and setup.Pk.Z should be the same (currently not, the correct one is the calculation used inside GenerateTrustedSetup function), the calculation is repeated. TODO avoid repeating calculation
  184. // assert.Equal(t, zxQAP, setup.Pk.Z)
  185. fmt.Println("hx pk.z", hxQAP)
  186. hx := Utils.PF.DivisorPolynomial(px, setup.Pk.Z)
  187. fmt.Println("hx pk.z", hx)
  188. // assert.Equal(t, hxQAP, hx)
  189. div, rem = Utils.PF.Div(px, setup.Pk.Z)
  190. assert.Equal(t, hx, div)
  191. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4))
  192. assert.Equal(t, px, Utils.PF.Mul(hxQAP, zxQAP))
  193. // hx==px/zx so px==hx*zx
  194. assert.Equal(t, px, Utils.PF.Mul(hx, setup.Pk.Z))
  195. // check length of polynomials H(x) and Z(x)
  196. assert.Equal(t, len(hx), len(px)-len(setup.Pk.Z)+1)
  197. assert.Equal(t, len(hxQAP), len(px)-len(zxQAP)+1)
  198. // fmt.Println("pk.Z", len(setup.Pk.Z))
  199. // fmt.Println("zxQAP", len(zxQAP))
  200. proof, err := GenerateProofs(*circuit, setup, w, px)
  201. assert.Nil(t, err)
  202. // fmt.Println("\n proofs:")
  203. // fmt.Println(proof)
  204. // fmt.Println("public signals:", proof.PublicSignals)
  205. fmt.Println("\n", circuit.Signals)
  206. fmt.Println("\nwitness", w)
  207. b12Verif := big.NewInt(int64(12))
  208. publicSignalsVerif := []*big.Int{b12Verif}
  209. before := time.Now()
  210. assert.True(t, VerifyProof(*circuit, setup, proof, publicSignalsVerif, true))
  211. fmt.Println("verify proof time elapsed:", time.Since(before))
  212. // check that with another public input the verification returns false
  213. bOtherWrongPublic := big.NewInt(int64(11))
  214. wrongPublicSignalsVerif := []*big.Int{bOtherWrongPublic}
  215. assert.True(t, !VerifyProof(*circuit, setup, proof, wrongPublicSignalsVerif, true))
  216. }
  217. /*
  218. func TestZkFromHardcodedR1CS(t *testing.T) {
  219. b0 := big.NewInt(int64(0))
  220. b1 := big.NewInt(int64(1))
  221. b3 := big.NewInt(int64(3))
  222. b5 := big.NewInt(int64(5))
  223. b9 := big.NewInt(int64(9))
  224. b27 := big.NewInt(int64(27))
  225. b30 := big.NewInt(int64(30))
  226. b35 := big.NewInt(int64(35))
  227. a := [][]*big.Int{
  228. []*big.Int{b0, b0, b1, b0, b0, b0},
  229. []*big.Int{b0, b0, b0, b1, b0, b0},
  230. []*big.Int{b0, b0, b1, b0, b1, b0},
  231. []*big.Int{b5, b0, b0, b0, b0, b1},
  232. }
  233. b := [][]*big.Int{
  234. []*big.Int{b0, b0, b1, b0, b0, b0},
  235. []*big.Int{b0, b0, b1, b0, b0, b0},
  236. []*big.Int{b1, b0, b0, b0, b0, b0},
  237. []*big.Int{b1, b0, b0, b0, b0, b0},
  238. }
  239. c := [][]*big.Int{
  240. []*big.Int{b0, b0, b0, b1, b0, b0},
  241. []*big.Int{b0, b0, b0, b0, b1, b0},
  242. []*big.Int{b0, b0, b0, b0, b0, b1},
  243. []*big.Int{b0, b1, b0, b0, b0, b0},
  244. }
  245. alphas, betas, gammas, zx := Utils.PF.R1CSToQAP(a, b, c)
  246. // wittness = 1, 35, 3, 9, 27, 30
  247. w := []*big.Int{b1, b35, b3, b9, b27, b30}
  248. circuit := circuitcompiler.Circuit{
  249. NVars: 6,
  250. NPublic: 1,
  251. NSignals: len(w),
  252. }
  253. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  254. hx := Utils.PF.DivisorPolynomial(px, zx)
  255. // hx==px/zx so px==hx*zx
  256. assert.Equal(t, px, Utils.PF.Mul(hx, zx))
  257. // p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
  258. abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx)
  259. assert.Equal(t, abc, px)
  260. hz := Utils.PF.Mul(hx, zx)
  261. assert.Equal(t, abc, hz)
  262. div, rem := Utils.PF.Div(px, zx)
  263. assert.Equal(t, hx, div)
  264. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4))
  265. // calculate trusted setup
  266. setup, err := GenerateTrustedSetup(len(w), circuit, alphas, betas, gammas, zx)
  267. assert.Nil(t, err)
  268. // piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t)
  269. proof, err := GenerateProofs(circuit, setup, hx, w)
  270. assert.Nil(t, err)
  271. // assert.True(t, VerifyProof(circuit, setup, proof, true))
  272. publicSignals := []*big.Int{b35}
  273. assert.True(t, VerifyProof(circuit, setup, proof, publicSignals, true))
  274. }
  275. func TestZkMultiplication(t *testing.T) {
  276. // compile circuit and get the R1CS
  277. flatCode := `
  278. func test(a, b):
  279. out = a * b
  280. `
  281. // parse the code
  282. parser := circuitcompiler.NewParser(strings.NewReader(flatCode))
  283. circuit, err := parser.Parse()
  284. assert.Nil(t, err)
  285. b3 := big.NewInt(int64(3))
  286. b4 := big.NewInt(int64(4))
  287. inputs := []*big.Int{b3, b4}
  288. // wittness
  289. w, err := circuit.CalculateWitness(inputs)
  290. assert.Nil(t, err)
  291. // flat code to R1CS
  292. a, b, c := circuit.GenerateR1CS()
  293. // R1CS to QAP
  294. alphas, betas, gammas, zx := Utils.PF.R1CSToQAP(a, b, c)
  295. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  296. hx := Utils.PF.DivisorPolynomial(px, zx)
  297. // hx==px/zx so px==hx*zx
  298. assert.Equal(t, px, Utils.PF.Mul(hx, zx))
  299. // p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
  300. abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx)
  301. assert.Equal(t, abc, px)
  302. hz := Utils.PF.Mul(hx, zx)
  303. assert.Equal(t, abc, hz)
  304. div, rem := Utils.PF.Div(px, zx)
  305. assert.Equal(t, hx, div)
  306. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(1))
  307. // calculate trusted setup
  308. setup, err := GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas, zx)
  309. assert.Nil(t, err)
  310. // piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t)
  311. proof, err := GenerateProofs(*circuit, setup, hx, w)
  312. assert.Nil(t, err)
  313. // assert.True(t, VerifyProof(*circuit, setup, proof, false))
  314. b35 := big.NewInt(int64(35))
  315. publicSignals := []*big.Int{b35}
  316. assert.True(t, VerifyProof(*circuit, setup, proof, publicSignals, true))
  317. }
  318. */