You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

335 lines
9.2 KiB

5 years ago
5 years ago
  1. package snark
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "math/big"
  7. "strings"
  8. "testing"
  9. "time"
  10. "github.com/arnaucube/go-snark/circuitcompiler"
  11. "github.com/arnaucube/go-snark/r1csqap"
  12. "github.com/stretchr/testify/assert"
  13. )
  14. func TestZkFromFlatCircuitCode(t *testing.T) {
  15. // compile circuit and get the R1CS
  16. // circuit function
  17. // y = x^3 + x + 5
  18. flatCode := `
  19. func test(private s0, public s1):
  20. s2 = s0 * s0
  21. s3 = s2 * s0
  22. s4 = s3 + s0
  23. s5 = s4 + 5
  24. s1 = s5 * 1
  25. s5 = s1 * 1
  26. out = 1 * 1
  27. `
  28. fmt.Print("\nflat code of the circuit:")
  29. fmt.Println(flatCode)
  30. // parse the code
  31. parser := circuitcompiler.NewParser(strings.NewReader(flatCode))
  32. circuit, err := parser.Parse()
  33. assert.Nil(t, err)
  34. fmt.Println("\ncircuit data:", circuit)
  35. circuitJson, _ := json.Marshal(circuit)
  36. fmt.Println("circuit:", string(circuitJson))
  37. b3 := big.NewInt(int64(3))
  38. privateInputs := []*big.Int{b3}
  39. b35 := big.NewInt(int64(35))
  40. publicSignals := []*big.Int{b35}
  41. // wittness
  42. w, err := circuit.CalculateWitness(privateInputs, publicSignals)
  43. assert.Nil(t, err)
  44. fmt.Println("\n", circuit.Signals)
  45. fmt.Println("witness", w)
  46. // flat code to R1CS
  47. fmt.Println("\ngenerating R1CS from flat code")
  48. a, b, c := circuit.GenerateR1CS()
  49. fmt.Println("\nR1CS:")
  50. fmt.Println("a:", a)
  51. fmt.Println("b:", b)
  52. fmt.Println("c:", c)
  53. // R1CS to QAP
  54. // TODO zxQAP is not used and is an old impl, bad calculated. TODO remove
  55. alphas, betas, gammas, zxQAP := Utils.PF.R1CSToQAP(a, b, c)
  56. fmt.Println("qap")
  57. fmt.Println("alphas", len(alphas))
  58. fmt.Println("alphas[1]", alphas[1])
  59. fmt.Println("betas", len(betas))
  60. fmt.Println("gammas", len(gammas))
  61. fmt.Println("zx length", len(zxQAP))
  62. assert.True(t, !bytes.Equal(alphas[1][1].Bytes(), big.NewInt(int64(0)).Bytes()))
  63. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  64. fmt.Println("ax length", len(ax))
  65. fmt.Println("bx length", len(bx))
  66. fmt.Println("cx length", len(cx))
  67. fmt.Println("px length", len(px))
  68. fmt.Println("px[last]", px[0])
  69. hxQAP := Utils.PF.DivisorPolynomial(px, zxQAP)
  70. fmt.Println("hx length", len(hxQAP))
  71. // hx==px/zx so px==hx*zx
  72. assert.Equal(t, px, Utils.PF.Mul(hxQAP, zxQAP))
  73. // p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
  74. abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx)
  75. assert.Equal(t, abc, px)
  76. hzQAP := Utils.PF.Mul(hxQAP, zxQAP)
  77. assert.Equal(t, abc, hzQAP)
  78. div, rem := Utils.PF.Div(px, zxQAP)
  79. assert.Equal(t, hxQAP, div)
  80. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(6))
  81. // calculate trusted setup
  82. setup, err := GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas)
  83. assert.Nil(t, err)
  84. fmt.Println("\nt:", setup.Toxic.T)
  85. // zx and setup.Pk.Z should be the same (currently not, the correct one is the calculation used inside GenerateTrustedSetup function), the calculation is repeated. TODO avoid repeating calculation
  86. // assert.Equal(t, zxQAP, setup.Pk.Z)
  87. fmt.Println("hx pk.z", hxQAP)
  88. hx := Utils.PF.DivisorPolynomial(px, setup.Pk.Z)
  89. fmt.Println("hx pk.z", hx)
  90. // assert.Equal(t, hxQAP, hx)
  91. div, rem = Utils.PF.Div(px, setup.Pk.Z)
  92. assert.Equal(t, hx, div)
  93. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(6))
  94. assert.Equal(t, px, Utils.PF.Mul(hxQAP, zxQAP))
  95. // hx==px/zx so px==hx*zx
  96. assert.Equal(t, px, Utils.PF.Mul(hx, setup.Pk.Z))
  97. // check length of polynomials H(x) and Z(x)
  98. assert.Equal(t, len(hx), len(px)-len(setup.Pk.Z)+1)
  99. assert.Equal(t, len(hxQAP), len(px)-len(zxQAP)+1)
  100. // fmt.Println("pk.Z", len(setup.Pk.Z))
  101. // fmt.Println("zxQAP", len(zxQAP))
  102. proof, err := GenerateProofs(*circuit, setup, w, px)
  103. assert.Nil(t, err)
  104. // fmt.Println("\n proofs:")
  105. // fmt.Println(proof)
  106. // fmt.Println("public signals:", proof.PublicSignals)
  107. fmt.Println("\nwitness", w)
  108. b35Verif := big.NewInt(int64(35))
  109. publicSignalsVerif := []*big.Int{b35Verif}
  110. before := time.Now()
  111. assert.True(t, VerifyProof(*circuit, setup, proof, publicSignalsVerif, true))
  112. fmt.Println("verify proof time elapsed:", time.Since(before))
  113. // check that with another public input the verification returns false
  114. bOtherWrongPublic := big.NewInt(int64(34))
  115. wrongPublicSignalsVerif := []*big.Int{bOtherWrongPublic}
  116. assert.True(t, !VerifyProof(*circuit, setup, proof, wrongPublicSignalsVerif, true))
  117. }
  118. /*
  119. func TestZkMultiplication(t *testing.T) {
  120. // compile circuit and get the R1CS
  121. flatCode := `
  122. func test(a, b):
  123. out = a * b
  124. `
  125. // parse the code
  126. parser := circuitcompiler.NewParser(strings.NewReader(flatCode))
  127. circuit, err := parser.Parse()
  128. assert.Nil(t, err)
  129. b3 := big.NewInt(int64(3))
  130. b4 := big.NewInt(int64(4))
  131. inputs := []*big.Int{b3, b4}
  132. // wittness
  133. w, err := circuit.CalculateWitness(inputs)
  134. assert.Nil(t, err)
  135. fmt.Println("circuit")
  136. fmt.Println(circuit.NPublic)
  137. // flat code to R1CS
  138. a, b, c := circuit.GenerateR1CS()
  139. fmt.Println("\nR1CS:")
  140. fmt.Println("a:", a)
  141. fmt.Println("b:", b)
  142. fmt.Println("c:", c)
  143. // R1CS to QAP
  144. alphas, betas, gammas, zx := Utils.PF.R1CSToQAP(a, b, c)
  145. fmt.Println("qap")
  146. fmt.Println("alphas", alphas)
  147. fmt.Println("betas", betas)
  148. fmt.Println("gammas", gammas)
  149. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  150. hx := Utils.PF.DivisorPolynomial(px, zx)
  151. // hx==px/zx so px==hx*zx
  152. assert.Equal(t, px, Utils.PF.Mul(hx, zx))
  153. // p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
  154. abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx)
  155. assert.Equal(t, abc, px)
  156. hz := Utils.PF.Mul(hx, zx)
  157. assert.Equal(t, abc, hz)
  158. div, rem := Utils.PF.Div(px, zx)
  159. assert.Equal(t, hx, div)
  160. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(1))
  161. // calculate trusted setup
  162. setup, err := GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas, zx)
  163. assert.Nil(t, err)
  164. // piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t)
  165. proof, err := GenerateProofs(*circuit, setup, hx, w)
  166. assert.Nil(t, err)
  167. // assert.True(t, VerifyProof(*circuit, setup, proof, false))
  168. b35 := big.NewInt(int64(35))
  169. publicSignals := []*big.Int{b35}
  170. assert.True(t, VerifyProof(*circuit, setup, proof, publicSignals, true))
  171. }
  172. */
  173. /*
  174. func TestZkFromHardcodedR1CS(t *testing.T) {
  175. b0 := big.NewInt(int64(0))
  176. b1 := big.NewInt(int64(1))
  177. b3 := big.NewInt(int64(3))
  178. b5 := big.NewInt(int64(5))
  179. b9 := big.NewInt(int64(9))
  180. b27 := big.NewInt(int64(27))
  181. b30 := big.NewInt(int64(30))
  182. b35 := big.NewInt(int64(35))
  183. a := [][]*big.Int{
  184. []*big.Int{b0, b0, b1, b0, b0, b0},
  185. []*big.Int{b0, b0, b0, b1, b0, b0},
  186. []*big.Int{b0, b0, b1, b0, b1, b0},
  187. []*big.Int{b5, b0, b0, b0, b0, b1},
  188. }
  189. b := [][]*big.Int{
  190. []*big.Int{b0, b0, b1, b0, b0, b0},
  191. []*big.Int{b0, b0, b1, b0, b0, b0},
  192. []*big.Int{b1, b0, b0, b0, b0, b0},
  193. []*big.Int{b1, b0, b0, b0, b0, b0},
  194. }
  195. c := [][]*big.Int{
  196. []*big.Int{b0, b0, b0, b1, b0, b0},
  197. []*big.Int{b0, b0, b0, b0, b1, b0},
  198. []*big.Int{b0, b0, b0, b0, b0, b1},
  199. []*big.Int{b0, b1, b0, b0, b0, b0},
  200. }
  201. alphas, betas, gammas, zx := Utils.PF.R1CSToQAP(a, b, c)
  202. // wittness = 1, 35, 3, 9, 27, 30
  203. w := []*big.Int{b1, b35, b3, b9, b27, b30}
  204. circuit := circuitcompiler.Circuit{
  205. NVars: 6,
  206. NPublic: 1,
  207. NSignals: len(w),
  208. }
  209. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  210. hx := Utils.PF.DivisorPolynomial(px, zx)
  211. // hx==px/zx so px==hx*zx
  212. assert.Equal(t, px, Utils.PF.Mul(hx, zx))
  213. // p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
  214. abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx)
  215. assert.Equal(t, abc, px)
  216. hz := Utils.PF.Mul(hx, zx)
  217. assert.Equal(t, abc, hz)
  218. div, rem := Utils.PF.Div(px, zx)
  219. assert.Equal(t, hx, div)
  220. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4))
  221. // calculate trusted setup
  222. setup, err := GenerateTrustedSetup(len(w), circuit, alphas, betas, gammas, zx)
  223. assert.Nil(t, err)
  224. // piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t)
  225. proof, err := GenerateProofs(circuit, setup, hx, w)
  226. assert.Nil(t, err)
  227. // assert.True(t, VerifyProof(circuit, setup, proof, true))
  228. publicSignals := []*big.Int{b35}
  229. assert.True(t, VerifyProof(circuit, setup, proof, publicSignals, true))
  230. }
  231. func TestZkMultiplication(t *testing.T) {
  232. // compile circuit and get the R1CS
  233. flatCode := `
  234. func test(a, b):
  235. out = a * b
  236. `
  237. // parse the code
  238. parser := circuitcompiler.NewParser(strings.NewReader(flatCode))
  239. circuit, err := parser.Parse()
  240. assert.Nil(t, err)
  241. b3 := big.NewInt(int64(3))
  242. b4 := big.NewInt(int64(4))
  243. inputs := []*big.Int{b3, b4}
  244. // wittness
  245. w, err := circuit.CalculateWitness(inputs)
  246. assert.Nil(t, err)
  247. // flat code to R1CS
  248. a, b, c := circuit.GenerateR1CS()
  249. // R1CS to QAP
  250. alphas, betas, gammas, zx := Utils.PF.R1CSToQAP(a, b, c)
  251. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  252. hx := Utils.PF.DivisorPolynomial(px, zx)
  253. // hx==px/zx so px==hx*zx
  254. assert.Equal(t, px, Utils.PF.Mul(hx, zx))
  255. // p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
  256. abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx)
  257. assert.Equal(t, abc, px)
  258. hz := Utils.PF.Mul(hx, zx)
  259. assert.Equal(t, abc, hz)
  260. div, rem := Utils.PF.Div(px, zx)
  261. assert.Equal(t, hx, div)
  262. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(1))
  263. // calculate trusted setup
  264. setup, err := GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas, zx)
  265. assert.Nil(t, err)
  266. // piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t)
  267. proof, err := GenerateProofs(*circuit, setup, hx, w)
  268. assert.Nil(t, err)
  269. // assert.True(t, VerifyProof(*circuit, setup, proof, false))
  270. b35 := big.NewInt(int64(35))
  271. publicSignals := []*big.Int{b35}
  272. assert.True(t, VerifyProof(*circuit, setup, proof, publicSignals, true))
  273. }
  274. */