You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

155 lines
3.9 KiB

  1. package snark
  2. import (
  3. "fmt"
  4. "math/big"
  5. "strings"
  6. "testing"
  7. "github.com/arnaucube/go-snark/bn128"
  8. "github.com/arnaucube/go-snark/circuitcompiler"
  9. "github.com/arnaucube/go-snark/fields"
  10. "github.com/arnaucube/go-snark/r1csqap"
  11. "github.com/stretchr/testify/assert"
  12. )
  13. func TestZkFromHardcodedR1CS(t *testing.T) {
  14. bn, err := bn128.NewBn128()
  15. assert.Nil(t, err)
  16. // new Finite Field
  17. fqR := fields.NewFq(bn.R)
  18. // new Polynomial Field
  19. pf := r1csqap.NewPolynomialField(fqR)
  20. b0 := big.NewInt(int64(0))
  21. b1 := big.NewInt(int64(1))
  22. b3 := big.NewInt(int64(3))
  23. b5 := big.NewInt(int64(5))
  24. b9 := big.NewInt(int64(9))
  25. b27 := big.NewInt(int64(27))
  26. b30 := big.NewInt(int64(30))
  27. b35 := big.NewInt(int64(35))
  28. a := [][]*big.Int{
  29. []*big.Int{b0, b1, b0, b0, b0, b0},
  30. []*big.Int{b0, b0, b0, b1, b0, b0},
  31. []*big.Int{b0, b1, b0, b0, b1, b0},
  32. []*big.Int{b5, b0, b0, b0, b0, b1},
  33. }
  34. b := [][]*big.Int{
  35. []*big.Int{b0, b1, b0, b0, b0, b0},
  36. []*big.Int{b0, b1, b0, b0, b0, b0},
  37. []*big.Int{b1, b0, b0, b0, b0, b0},
  38. []*big.Int{b1, b0, b0, b0, b0, b0},
  39. }
  40. c := [][]*big.Int{
  41. []*big.Int{b0, b0, b0, b1, b0, b0},
  42. []*big.Int{b0, b0, b0, b0, b1, b0},
  43. []*big.Int{b0, b0, b0, b0, b0, b1},
  44. []*big.Int{b0, b0, b1, b0, b0, b0},
  45. }
  46. alphas, betas, gammas, zx := pf.R1CSToQAP(a, b, c)
  47. // wittness = 1, 3, 35, 9, 27, 30
  48. w := []*big.Int{b1, b3, b35, b9, b27, b30}
  49. circuit := circuitcompiler.Circuit{
  50. NVars: 6,
  51. NPublic: 0,
  52. NSignals: len(w),
  53. }
  54. ax, bx, cx, px := pf.CombinePolynomials(w, alphas, betas, gammas)
  55. hx := pf.DivisorPolinomial(px, zx)
  56. // hx==px/zx so px==hx*zx
  57. assert.Equal(t, px, pf.Mul(hx, zx))
  58. // p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
  59. abc := pf.Sub(pf.Mul(ax, bx), cx)
  60. assert.Equal(t, abc, px)
  61. hz := pf.Mul(hx, zx)
  62. assert.Equal(t, abc, hz)
  63. div, rem := pf.Div(px, zx)
  64. assert.Equal(t, hx, div)
  65. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4))
  66. // calculate trusted setup
  67. setup, err := GenerateTrustedSetup(bn, fqR, pf, len(w), circuit, alphas, betas, gammas, zx)
  68. assert.Nil(t, err)
  69. fmt.Println("t", setup.Toxic.T)
  70. // piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t)
  71. proof, err := GenerateProofs(bn, fqR, circuit, setup, hx, w)
  72. assert.Nil(t, err)
  73. assert.True(t, VerifyProof(bn, circuit, setup, proof))
  74. }
  75. func TestZkFromFlatCircuitCode(t *testing.T) {
  76. bn, err := bn128.NewBn128()
  77. assert.Nil(t, err)
  78. // new Finite Field
  79. fqR := fields.NewFq(bn.R)
  80. // new Polynomial Field
  81. pf := r1csqap.NewPolynomialField(fqR)
  82. // compile circuit and get the R1CS
  83. flatCode := `
  84. func test(x):
  85. aux = x*x
  86. y = aux*x
  87. z = x + y
  88. out = z + 5
  89. `
  90. // parse the code
  91. parser := circuitcompiler.NewParser(strings.NewReader(flatCode))
  92. circuit, err := parser.Parse()
  93. assert.Nil(t, err)
  94. fmt.Println(circuit)
  95. // flat code to R1CS
  96. fmt.Println("generating R1CS from flat code")
  97. a, b, c := circuit.GenerateR1CS()
  98. alphas, betas, gammas, zx := pf.R1CSToQAP(a, b, c)
  99. // wittness = 1, 3, 35, 9, 27, 30
  100. b1 := big.NewInt(int64(1))
  101. b3 := big.NewInt(int64(3))
  102. b9 := big.NewInt(int64(9))
  103. b27 := big.NewInt(int64(27))
  104. b30 := big.NewInt(int64(30))
  105. b35 := big.NewInt(int64(35))
  106. w := []*big.Int{b1, b3, b35, b9, b27, b30}
  107. ax, bx, cx, px := pf.CombinePolynomials(w, alphas, betas, gammas)
  108. hx := pf.DivisorPolinomial(px, zx)
  109. // hx==px/zx so px==hx*zx
  110. assert.Equal(t, px, pf.Mul(hx, zx))
  111. // p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
  112. abc := pf.Sub(pf.Mul(ax, bx), cx)
  113. assert.Equal(t, abc, px)
  114. hz := pf.Mul(hx, zx)
  115. assert.Equal(t, abc, hz)
  116. div, rem := pf.Div(px, zx)
  117. assert.Equal(t, hx, div)
  118. assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4))
  119. // calculate trusted setup
  120. setup, err := GenerateTrustedSetup(bn, fqR, pf, len(w), *circuit, alphas, betas, gammas, zx)
  121. assert.Nil(t, err)
  122. fmt.Println("t", setup.Toxic.T)
  123. // piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t)
  124. proof, err := GenerateProofs(bn, fqR, *circuit, setup, hx, w)
  125. assert.Nil(t, err)
  126. assert.True(t, VerifyProof(bn, *circuit, setup, proof))
  127. }