You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

206 lines
5.2 KiB

5 years ago
5 years ago
  1. package snark
  2. import (
  3. "fmt"
  4. "math/big"
  5. "strings"
  6. "testing"
  7. "time"
  8. "github.com/arnaucube/go-snark/circuitcompiler"
  9. "github.com/arnaucube/go-snark/r1csqap"
  10. "github.com/stretchr/testify/assert"
  11. )
  12. func TestZkFromFlatCircuitCode(t *testing.T) {
  13. // compile circuit and get the R1CS
  14. flatCode := `
  15. func test(x):
  16. aux = x*x
  17. y = aux*x
  18. z = x + y
  19. out = z + 5
  20. `
  21. fmt.Print("\nflat code of the circuit:")
  22. fmt.Println(flatCode)
  23. // parse the code
  24. parser := circuitcompiler.NewParser(strings.NewReader(flatCode))
  25. circuit, err := parser.Parse()
  26. assert.Nil(t, err)
  27. fmt.Println("\ncircuit data:", circuit)
  28. b3 := big.NewInt(int64(3))
  29. inputs := []*big.Int{b3}
  30. // wittness
  31. w, err := circuit.CalculateWitness(inputs)
  32. assert.Nil(t, err)
  33. fmt.Println("\nwitness", w)
  34. // flat code to R1CS
  35. fmt.Println("\ngenerating R1CS from flat code")
  36. a, b, c := circuit.GenerateR1CS()
  37. fmt.Println("\nR1CS:")
  38. fmt.Println("a:", a)
  39. fmt.Println("b:", b)
  40. fmt.Println("c:", c)
  41. // R1CS to QAP
  42. alphas, betas, gammas, zx := Utils.PF.R1CSToQAP(a, b, c)
  43. fmt.Println("qap")
  44. fmt.Println(alphas)
  45. fmt.Println(betas)
  46. fmt.Println(gammas)
  47. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  48. hx := Utils.PF.DivisorPolynomial(px, zx)
  49. // hx==px/zx so px==hx*zx
  50. assert.Equal(t, px, Utils.PF.Mul(hx, zx))
  51. // p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
  52. abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx)
  53. assert.Equal(t, abc, px)
  54. hz := Utils.PF.Mul(hx, zx)
  55. assert.Equal(t, abc, hz)
  56. div, rem := Utils.PF.Div(px, zx)
  57. assert.Equal(t, hx, div)
  58. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4))
  59. // calculate trusted setup
  60. setup, err := GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas, zx)
  61. assert.Nil(t, err)
  62. fmt.Println("\nt:", setup.Toxic.T)
  63. // piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t)
  64. proof, err := GenerateProofs(*circuit, setup, hx, w)
  65. assert.Nil(t, err)
  66. fmt.Println("\n proofs:")
  67. fmt.Println(proof)
  68. fmt.Println("public signals:", proof.PublicSignals)
  69. before := time.Now()
  70. assert.True(t, VerifyProof(*circuit, setup, proof, true))
  71. fmt.Println("verify proof time elapsed:", time.Since(before))
  72. }
  73. func TestZkFromHardcodedR1CS(t *testing.T) {
  74. b0 := big.NewInt(int64(0))
  75. b1 := big.NewInt(int64(1))
  76. b3 := big.NewInt(int64(3))
  77. b5 := big.NewInt(int64(5))
  78. b9 := big.NewInt(int64(9))
  79. b27 := big.NewInt(int64(27))
  80. b30 := big.NewInt(int64(30))
  81. b35 := big.NewInt(int64(35))
  82. a := [][]*big.Int{
  83. []*big.Int{b0, b0, b1, b0, b0, b0},
  84. []*big.Int{b0, b0, b0, b1, b0, b0},
  85. []*big.Int{b0, b0, b1, b0, b1, b0},
  86. []*big.Int{b5, b0, b0, b0, b0, b1},
  87. }
  88. b := [][]*big.Int{
  89. []*big.Int{b0, b0, b1, b0, b0, b0},
  90. []*big.Int{b0, b0, b1, b0, b0, b0},
  91. []*big.Int{b1, b0, b0, b0, b0, b0},
  92. []*big.Int{b1, b0, b0, b0, b0, b0},
  93. }
  94. c := [][]*big.Int{
  95. []*big.Int{b0, b0, b0, b1, b0, b0},
  96. []*big.Int{b0, b0, b0, b0, b1, b0},
  97. []*big.Int{b0, b0, b0, b0, b0, b1},
  98. []*big.Int{b0, b1, b0, b0, b0, b0},
  99. }
  100. alphas, betas, gammas, zx := Utils.PF.R1CSToQAP(a, b, c)
  101. // wittness = 1, 35, 3, 9, 27, 30
  102. w := []*big.Int{b1, b35, b3, b9, b27, b30}
  103. circuit := circuitcompiler.Circuit{
  104. NVars: 6,
  105. NPublic: 1,
  106. NSignals: len(w),
  107. }
  108. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  109. hx := Utils.PF.DivisorPolynomial(px, zx)
  110. // hx==px/zx so px==hx*zx
  111. assert.Equal(t, px, Utils.PF.Mul(hx, zx))
  112. // p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
  113. abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx)
  114. assert.Equal(t, abc, px)
  115. hz := Utils.PF.Mul(hx, zx)
  116. assert.Equal(t, abc, hz)
  117. div, rem := Utils.PF.Div(px, zx)
  118. assert.Equal(t, hx, div)
  119. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4))
  120. // calculate trusted setup
  121. setup, err := GenerateTrustedSetup(len(w), circuit, alphas, betas, gammas, zx)
  122. assert.Nil(t, err)
  123. // piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t)
  124. proof, err := GenerateProofs(circuit, setup, hx, w)
  125. assert.Nil(t, err)
  126. assert.True(t, VerifyProof(circuit, setup, proof, true))
  127. }
  128. func TestZkMultiplication(t *testing.T) {
  129. // compile circuit and get the R1CS
  130. flatCode := `
  131. func test(a, b):
  132. out = a * b
  133. `
  134. // parse the code
  135. parser := circuitcompiler.NewParser(strings.NewReader(flatCode))
  136. circuit, err := parser.Parse()
  137. assert.Nil(t, err)
  138. b3 := big.NewInt(int64(3))
  139. b4 := big.NewInt(int64(4))
  140. inputs := []*big.Int{b3, b4}
  141. // wittness
  142. w, err := circuit.CalculateWitness(inputs)
  143. assert.Nil(t, err)
  144. // flat code to R1CS
  145. a, b, c := circuit.GenerateR1CS()
  146. // R1CS to QAP
  147. alphas, betas, gammas, zx := Utils.PF.R1CSToQAP(a, b, c)
  148. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  149. hx := Utils.PF.DivisorPolynomial(px, zx)
  150. // hx==px/zx so px==hx*zx
  151. assert.Equal(t, px, Utils.PF.Mul(hx, zx))
  152. // p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
  153. abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx)
  154. assert.Equal(t, abc, px)
  155. hz := Utils.PF.Mul(hx, zx)
  156. assert.Equal(t, abc, hz)
  157. div, rem := Utils.PF.Div(px, zx)
  158. assert.Equal(t, hx, div)
  159. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(1))
  160. // calculate trusted setup
  161. setup, err := GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas, zx)
  162. assert.Nil(t, err)
  163. // piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t)
  164. proof, err := GenerateProofs(*circuit, setup, hx, w)
  165. assert.Nil(t, err)
  166. assert.True(t, VerifyProof(*circuit, setup, proof, false))
  167. }