You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

321 lines
8.7 KiB

5 years ago
5 years ago
  1. package snark
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "math/big"
  6. "strings"
  7. "testing"
  8. "time"
  9. "github.com/arnaucube/go-snark/circuitcompiler"
  10. "github.com/arnaucube/go-snark/r1csqap"
  11. "github.com/stretchr/testify/assert"
  12. )
  13. func TestZkFromFlatCircuitCode(t *testing.T) {
  14. // compile circuit and get the R1CS
  15. flatCode := `
  16. func test(x):
  17. aux = x*x
  18. y = aux*x
  19. z = x + y
  20. out = z + 5
  21. `
  22. fmt.Print("\nflat code of the circuit:")
  23. fmt.Println(flatCode)
  24. // parse the code
  25. parser := circuitcompiler.NewParser(strings.NewReader(flatCode))
  26. circuit, err := parser.Parse()
  27. assert.Nil(t, err)
  28. fmt.Println("\ncircuit data:", circuit)
  29. circuitJson, _ := json.Marshal(circuit)
  30. fmt.Println("circuit:", string(circuitJson))
  31. b3 := big.NewInt(int64(3))
  32. privateInputs := []*big.Int{b3}
  33. // wittness
  34. w, err := circuit.CalculateWitness(privateInputs)
  35. assert.Nil(t, err)
  36. fmt.Println("\nwitness", w)
  37. // flat code to R1CS
  38. fmt.Println("\ngenerating R1CS from flat code")
  39. a, b, c := circuit.GenerateR1CS()
  40. fmt.Println("\nR1CS:")
  41. fmt.Println("a:", a)
  42. fmt.Println("b:", b)
  43. fmt.Println("c:", c)
  44. // R1CS to QAP
  45. // TODO zxQAP is not used and is an old impl, bad calculated. TODO remove
  46. alphas, betas, gammas, zxQAP := Utils.PF.R1CSToQAP(a, b, c)
  47. fmt.Println("qap")
  48. fmt.Println("alphas", len(alphas))
  49. fmt.Println("alphas[1]", alphas[1])
  50. fmt.Println("betas", len(betas))
  51. fmt.Println("gammas", len(gammas))
  52. fmt.Println("zx length", len(zxQAP))
  53. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  54. fmt.Println("ax length", len(ax))
  55. fmt.Println("bx length", len(bx))
  56. fmt.Println("cx length", len(cx))
  57. fmt.Println("px length", len(px))
  58. fmt.Println("px[last]", px[0])
  59. px0 := Utils.PF.F.Add(px[0], big.NewInt(int64(88)))
  60. fmt.Println(px0)
  61. assert.Equal(t, px0.Bytes(), Utils.PF.F.Zero().Bytes())
  62. hxQAP := Utils.PF.DivisorPolynomial(px, zxQAP)
  63. fmt.Println("hx length", len(hxQAP))
  64. // hx==px/zx so px==hx*zx
  65. assert.Equal(t, px, Utils.PF.Mul(hxQAP, zxQAP))
  66. // p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
  67. abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx)
  68. assert.Equal(t, abc, px)
  69. hzQAP := Utils.PF.Mul(hxQAP, zxQAP)
  70. assert.Equal(t, abc, hzQAP)
  71. div, rem := Utils.PF.Div(px, zxQAP)
  72. assert.Equal(t, hxQAP, div)
  73. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4))
  74. // calculate trusted setup
  75. setup, err := GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas)
  76. assert.Nil(t, err)
  77. fmt.Println("\nt:", setup.Toxic.T)
  78. // zx and setup.Pk.Z should be the same (currently not, the correct one is the calculation used inside GenerateTrustedSetup function), the calculation is repeated. TODO avoid repeating calculation
  79. // assert.Equal(t, zxQAP, setup.Pk.Z)
  80. fmt.Println("hx pk.z", hxQAP)
  81. hx := Utils.PF.DivisorPolynomial(px, setup.Pk.Z)
  82. fmt.Println("hx pk.z", hx)
  83. // assert.Equal(t, hxQAP, hx)
  84. assert.Equal(t, px, Utils.PF.Mul(hxQAP, zxQAP))
  85. // hx==px/zx so px==hx*zx
  86. assert.Equal(t, px, Utils.PF.Mul(hx, setup.Pk.Z))
  87. // check length of polynomials H(x) and Z(x)
  88. assert.Equal(t, len(hx), len(px)-len(setup.Pk.Z)+1)
  89. assert.Equal(t, len(hxQAP), len(px)-len(zxQAP)+1)
  90. // fmt.Println("pk.Z", len(setup.Pk.Z))
  91. // fmt.Println("zxQAP", len(zxQAP))
  92. proof, err := GenerateProofs(*circuit, setup, w, px)
  93. assert.Nil(t, err)
  94. // fmt.Println("\n proofs:")
  95. // fmt.Println(proof)
  96. // fmt.Println("public signals:", proof.PublicSignals)
  97. fmt.Println("\nwitness", w)
  98. // b1 := big.NewInt(int64(1))
  99. b35 := big.NewInt(int64(35))
  100. // publicSignals := []*big.Int{b1, b35}
  101. publicSignals := []*big.Int{b35}
  102. before := time.Now()
  103. assert.True(t, VerifyProof(*circuit, setup, proof, publicSignals, true))
  104. fmt.Println("verify proof time elapsed:", time.Since(before))
  105. }
  106. /*
  107. func TestZkMultiplication(t *testing.T) {
  108. // compile circuit and get the R1CS
  109. flatCode := `
  110. func test(a, b):
  111. out = a * b
  112. `
  113. // parse the code
  114. parser := circuitcompiler.NewParser(strings.NewReader(flatCode))
  115. circuit, err := parser.Parse()
  116. assert.Nil(t, err)
  117. b3 := big.NewInt(int64(3))
  118. b4 := big.NewInt(int64(4))
  119. inputs := []*big.Int{b3, b4}
  120. // witness
  121. w, err := circuit.CalculateWitness(inputs)
  122. assert.Nil(t, err)
  123. fmt.Println("circuit")
  124. fmt.Println(circuit.NPublic)
  125. // flat code to R1CS
  126. a, b, c := circuit.GenerateR1CS()
  127. fmt.Println("\nR1CS:")
  128. fmt.Println("a:", a)
  129. fmt.Println("b:", b)
  130. fmt.Println("c:", c)
  131. // R1CS to QAP
  132. alphas, betas, gammas, zx := Utils.PF.R1CSToQAP(a, b, c)
  133. fmt.Println("qap")
  134. fmt.Println("alphas", alphas)
  135. fmt.Println("betas", betas)
  136. fmt.Println("gammas", gammas)
  137. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  138. hx := Utils.PF.DivisorPolynomial(px, zx)
  139. // hx==px/zx so px==hx*zx
  140. assert.Equal(t, px, Utils.PF.Mul(hx, zx))
  141. // p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
  142. abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx)
  143. assert.Equal(t, abc, px)
  144. hz := Utils.PF.Mul(hx, zx)
  145. assert.Equal(t, abc, hz)
  146. div, rem := Utils.PF.Div(px, zx)
  147. assert.Equal(t, hx, div)
  148. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(1))
  149. // calculate trusted setup
  150. setup, err := GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas, zx)
  151. assert.Nil(t, err)
  152. // piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t)
  153. proof, err := GenerateProofs(*circuit, setup, hx, w)
  154. assert.Nil(t, err)
  155. // assert.True(t, VerifyProof(*circuit, setup, proof, false))
  156. b35 := big.NewInt(int64(35))
  157. publicSignals := []*big.Int{b35}
  158. assert.True(t, VerifyProof(*circuit, setup, proof, publicSignals, true))
  159. }
  160. */
  161. /*
  162. func TestZkFromHardcodedR1CS(t *testing.T) {
  163. b0 := big.NewInt(int64(0))
  164. b1 := big.NewInt(int64(1))
  165. b3 := big.NewInt(int64(3))
  166. b5 := big.NewInt(int64(5))
  167. b9 := big.NewInt(int64(9))
  168. b27 := big.NewInt(int64(27))
  169. b30 := big.NewInt(int64(30))
  170. b35 := big.NewInt(int64(35))
  171. a := [][]*big.Int{
  172. []*big.Int{b0, b0, b1, b0, b0, b0},
  173. []*big.Int{b0, b0, b0, b1, b0, b0},
  174. []*big.Int{b0, b0, b1, b0, b1, b0},
  175. []*big.Int{b5, b0, b0, b0, b0, b1},
  176. }
  177. b := [][]*big.Int{
  178. []*big.Int{b0, b0, b1, b0, b0, b0},
  179. []*big.Int{b0, b0, b1, b0, b0, b0},
  180. []*big.Int{b1, b0, b0, b0, b0, b0},
  181. []*big.Int{b1, b0, b0, b0, b0, b0},
  182. }
  183. c := [][]*big.Int{
  184. []*big.Int{b0, b0, b0, b1, b0, b0},
  185. []*big.Int{b0, b0, b0, b0, b1, b0},
  186. []*big.Int{b0, b0, b0, b0, b0, b1},
  187. []*big.Int{b0, b1, b0, b0, b0, b0},
  188. }
  189. alphas, betas, gammas, zx := Utils.PF.R1CSToQAP(a, b, c)
  190. // witness = 1, 35, 3, 9, 27, 30
  191. w := []*big.Int{b1, b35, b3, b9, b27, b30}
  192. circuit := circuitcompiler.Circuit{
  193. NVars: 6,
  194. NPublic: 1,
  195. NSignals: len(w),
  196. }
  197. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  198. hx := Utils.PF.DivisorPolynomial(px, zx)
  199. // hx==px/zx so px==hx*zx
  200. assert.Equal(t, px, Utils.PF.Mul(hx, zx))
  201. // p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
  202. abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx)
  203. assert.Equal(t, abc, px)
  204. hz := Utils.PF.Mul(hx, zx)
  205. assert.Equal(t, abc, hz)
  206. div, rem := Utils.PF.Div(px, zx)
  207. assert.Equal(t, hx, div)
  208. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4))
  209. // calculate trusted setup
  210. setup, err := GenerateTrustedSetup(len(w), circuit, alphas, betas, gammas, zx)
  211. assert.Nil(t, err)
  212. // piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t)
  213. proof, err := GenerateProofs(circuit, setup, hx, w)
  214. assert.Nil(t, err)
  215. // assert.True(t, VerifyProof(circuit, setup, proof, true))
  216. publicSignals := []*big.Int{b35}
  217. assert.True(t, VerifyProof(circuit, setup, proof, publicSignals, true))
  218. }
  219. func TestZkMultiplication(t *testing.T) {
  220. // compile circuit and get the R1CS
  221. flatCode := `
  222. func test(a, b):
  223. out = a * b
  224. `
  225. // parse the code
  226. parser := circuitcompiler.NewParser(strings.NewReader(flatCode))
  227. circuit, err := parser.Parse()
  228. assert.Nil(t, err)
  229. b3 := big.NewInt(int64(3))
  230. b4 := big.NewInt(int64(4))
  231. inputs := []*big.Int{b3, b4}
  232. // witness
  233. w, err := circuit.CalculateWitness(inputs)
  234. assert.Nil(t, err)
  235. // flat code to R1CS
  236. a, b, c := circuit.GenerateR1CS()
  237. // R1CS to QAP
  238. alphas, betas, gammas, zx := Utils.PF.R1CSToQAP(a, b, c)
  239. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  240. hx := Utils.PF.DivisorPolynomial(px, zx)
  241. // hx==px/zx so px==hx*zx
  242. assert.Equal(t, px, Utils.PF.Mul(hx, zx))
  243. // p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
  244. abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx)
  245. assert.Equal(t, abc, px)
  246. hz := Utils.PF.Mul(hx, zx)
  247. assert.Equal(t, abc, hz)
  248. div, rem := Utils.PF.Div(px, zx)
  249. assert.Equal(t, hx, div)
  250. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(1))
  251. // calculate trusted setup
  252. setup, err := GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas, zx)
  253. assert.Nil(t, err)
  254. // piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t)
  255. proof, err := GenerateProofs(*circuit, setup, hx, w)
  256. assert.Nil(t, err)
  257. // assert.True(t, VerifyProof(*circuit, setup, proof, false))
  258. b35 := big.NewInt(int64(35))
  259. publicSignals := []*big.Int{b35}
  260. assert.True(t, VerifyProof(*circuit, setup, proof, publicSignals, true))
  261. }
  262. */