You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

206 lines
5.2 KiB

5 years ago
5 years ago
5 years ago
  1. package snark
  2. import (
  3. "fmt"
  4. "math/big"
  5. "strings"
  6. "testing"
  7. "time"
  8. "github.com/arnaucube/go-snark/circuitcompiler"
  9. "github.com/arnaucube/go-snark/r1csqap"
  10. "github.com/stretchr/testify/assert"
  11. )
  12. func TestZkFromFlatCircuitCode(t *testing.T) {
  13. // compile circuit and get the R1CS
  14. flatCode := `
  15. func test(x):
  16. aux = x*x
  17. y = aux*x
  18. z = x + y
  19. out = z + 5
  20. `
  21. fmt.Print("\nflat code of the circuit:")
  22. fmt.Println(flatCode)
  23. // parse the code
  24. parser := circuitcompiler.NewParser(strings.NewReader(flatCode))
  25. circuit, err := parser.Parse()
  26. assert.Nil(t, err)
  27. fmt.Println("\ncircuit data:", circuit)
  28. b3 := big.NewInt(int64(3))
  29. inputs := []*big.Int{b3}
  30. // wittness
  31. w, err := circuit.CalculateWitness(inputs)
  32. assert.Nil(t, err)
  33. fmt.Println("\nwitness", w)
  34. // flat code to R1CS
  35. fmt.Println("\ngenerating R1CS from flat code")
  36. a, b, c := circuit.GenerateR1CS()
  37. fmt.Println("\nR1CS:")
  38. fmt.Println("a:", a)
  39. fmt.Println("b:", b)
  40. fmt.Println("c:", c)
  41. // R1CS to QAP
  42. alphas, betas, gammas, zx := Utils.PF.R1CSToQAP(a, b, c)
  43. fmt.Println("qap")
  44. fmt.Println(alphas)
  45. fmt.Println(betas)
  46. fmt.Println(gammas)
  47. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  48. hx := Utils.PF.DivisorPolynomial(px, zx)
  49. // hx==px/zx so px==hx*zx
  50. assert.Equal(t, px, Utils.PF.Mul(hx, zx))
  51. // p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
  52. abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx)
  53. assert.Equal(t, abc, px)
  54. hz := Utils.PF.Mul(hx, zx)
  55. assert.Equal(t, abc, hz)
  56. div, rem := Utils.PF.Div(px, zx)
  57. assert.Equal(t, hx, div)
  58. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4))
  59. // calculate trusted setup
  60. setup, err := GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas, zx)
  61. assert.Nil(t, err)
  62. fmt.Println("\nt:", setup.Toxic.T)
  63. // piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t)
  64. proof, err := GenerateProofs(*circuit, setup, hx, w)
  65. assert.Nil(t, err)
  66. fmt.Println("\n proofs:")
  67. fmt.Println(proof)
  68. before := time.Now()
  69. assert.True(t, VerifyProof(*circuit, setup, proof, false))
  70. fmt.Println("verify proof time elapsed:", time.Since(before))
  71. }
  72. func TestZkFromHardcodedR1CS(t *testing.T) {
  73. b0 := big.NewInt(int64(0))
  74. b1 := big.NewInt(int64(1))
  75. b3 := big.NewInt(int64(3))
  76. b5 := big.NewInt(int64(5))
  77. b9 := big.NewInt(int64(9))
  78. b27 := big.NewInt(int64(27))
  79. b30 := big.NewInt(int64(30))
  80. b35 := big.NewInt(int64(35))
  81. a := [][]*big.Int{
  82. []*big.Int{b0, b1, b0, b0, b0, b0},
  83. []*big.Int{b0, b0, b0, b1, b0, b0},
  84. []*big.Int{b0, b1, b0, b0, b1, b0},
  85. []*big.Int{b5, b0, b0, b0, b0, b1},
  86. }
  87. b := [][]*big.Int{
  88. []*big.Int{b0, b1, b0, b0, b0, b0},
  89. []*big.Int{b0, b1, b0, b0, b0, b0},
  90. []*big.Int{b1, b0, b0, b0, b0, b0},
  91. []*big.Int{b1, b0, b0, b0, b0, b0},
  92. }
  93. c := [][]*big.Int{
  94. []*big.Int{b0, b0, b0, b1, b0, b0},
  95. []*big.Int{b0, b0, b0, b0, b1, b0},
  96. []*big.Int{b0, b0, b0, b0, b0, b1},
  97. []*big.Int{b0, b0, b1, b0, b0, b0},
  98. }
  99. alphas, betas, gammas, zx := Utils.PF.R1CSToQAP(a, b, c)
  100. // wittness = 1, 3, 35, 9, 27, 30
  101. w := []*big.Int{b1, b3, b35, b9, b27, b30}
  102. circuit := circuitcompiler.Circuit{
  103. NVars: 6,
  104. NPublic: 0,
  105. NSignals: len(w),
  106. }
  107. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  108. hx := Utils.PF.DivisorPolynomial(px, zx)
  109. // hx==px/zx so px==hx*zx
  110. assert.Equal(t, px, Utils.PF.Mul(hx, zx))
  111. // p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
  112. abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx)
  113. assert.Equal(t, abc, px)
  114. hz := Utils.PF.Mul(hx, zx)
  115. assert.Equal(t, abc, hz)
  116. div, rem := Utils.PF.Div(px, zx)
  117. assert.Equal(t, hx, div)
  118. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4))
  119. // calculate trusted setup
  120. setup, err := GenerateTrustedSetup(len(w), circuit, alphas, betas, gammas, zx)
  121. assert.Nil(t, err)
  122. // piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t)
  123. proof, err := GenerateProofs(circuit, setup, hx, w)
  124. assert.Nil(t, err)
  125. assert.True(t, VerifyProof(circuit, setup, proof, true))
  126. }
  127. func TestZkMultiplication(t *testing.T) {
  128. // compile circuit and get the R1CS
  129. flatCode := `
  130. func test(a, b):
  131. out = a * b
  132. `
  133. // parse the code
  134. parser := circuitcompiler.NewParser(strings.NewReader(flatCode))
  135. circuit, err := parser.Parse()
  136. assert.Nil(t, err)
  137. b3 := big.NewInt(int64(3))
  138. b4 := big.NewInt(int64(4))
  139. inputs := []*big.Int{b3, b4}
  140. // wittness
  141. w, err := circuit.CalculateWitness(inputs)
  142. assert.Nil(t, err)
  143. // flat code to R1CS
  144. a, b, c := circuit.GenerateR1CS()
  145. // R1CS to QAP
  146. alphas, betas, gammas, zx := Utils.PF.R1CSToQAP(a, b, c)
  147. ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas)
  148. hx := Utils.PF.DivisorPolynomial(px, zx)
  149. // hx==px/zx so px==hx*zx
  150. assert.Equal(t, px, Utils.PF.Mul(hx, zx))
  151. // p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
  152. abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx)
  153. assert.Equal(t, abc, px)
  154. hz := Utils.PF.Mul(hx, zx)
  155. assert.Equal(t, abc, hz)
  156. div, rem := Utils.PF.Div(px, zx)
  157. assert.Equal(t, hx, div)
  158. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(1))
  159. // calculate trusted setup
  160. setup, err := GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas, zx)
  161. assert.Nil(t, err)
  162. // piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t)
  163. proof, err := GenerateProofs(*circuit, setup, hx, w)
  164. assert.Nil(t, err)
  165. assert.True(t, VerifyProof(*circuit, setup, proof, false))
  166. }