You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

239 lines
7.0 KiB

5 years ago
  1. package circuit
  2. import (
  3. "bufio"
  4. "math/big"
  5. "os"
  6. "strings"
  7. "testing"
  8. "github.com/stretchr/testify/assert"
  9. )
  10. func TestCircuitParser(t *testing.T) {
  11. // y = x^3 + x + 5
  12. flat := `
  13. func main(private s0, public s1):
  14. s2 = s0 * s0
  15. s3 = s2 * s0
  16. s4 = s3 + s0
  17. s5 = s4 + 5
  18. equals(s1, s5)
  19. out = 1 * 1
  20. `
  21. parser := NewParser(strings.NewReader(flat))
  22. cir, err := parser.Parse()
  23. assert.Nil(t, err)
  24. // flat code to R1CS
  25. cir.GenerateR1CS()
  26. assert.Equal(t, "s0", cir.PrivateInputs[0])
  27. assert.Equal(t, "s1", cir.PublicInputs[0])
  28. assert.Equal(t, []string{"one", "s1", "s0", "s2", "s3", "s4", "s5", "out"}, cir.Signals)
  29. // expected result
  30. b0 := big.NewInt(int64(0))
  31. b1 := big.NewInt(int64(1))
  32. b5 := big.NewInt(int64(5))
  33. aExpected := [][]*big.Int{
  34. []*big.Int{b0, b0, b1, b0, b0, b0, b0, b0},
  35. []*big.Int{b0, b0, b0, b1, b0, b0, b0, b0},
  36. []*big.Int{b0, b0, b1, b0, b1, b0, b0, b0},
  37. []*big.Int{b5, b0, b0, b0, b0, b1, b0, b0},
  38. []*big.Int{b0, b0, b0, b0, b0, b0, b1, b0},
  39. []*big.Int{b0, b1, b0, b0, b0, b0, b0, b0},
  40. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  41. }
  42. bExpected := [][]*big.Int{
  43. []*big.Int{b0, b0, b1, b0, b0, b0, b0, b0},
  44. []*big.Int{b0, b0, b1, b0, b0, b0, b0, b0},
  45. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  46. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  47. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  48. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  49. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  50. }
  51. cExpected := [][]*big.Int{
  52. []*big.Int{b0, b0, b0, b1, b0, b0, b0, b0},
  53. []*big.Int{b0, b0, b0, b0, b1, b0, b0, b0},
  54. []*big.Int{b0, b0, b0, b0, b0, b1, b0, b0},
  55. []*big.Int{b0, b0, b0, b0, b0, b0, b1, b0},
  56. []*big.Int{b0, b1, b0, b0, b0, b0, b0, b0},
  57. []*big.Int{b0, b0, b0, b0, b0, b0, b1, b0},
  58. []*big.Int{b0, b0, b0, b0, b0, b0, b0, b1},
  59. }
  60. assert.Equal(t, aExpected, cir.R1CS.A)
  61. assert.Equal(t, bExpected, cir.R1CS.B)
  62. assert.Equal(t, cExpected, cir.R1CS.C)
  63. b3 := big.NewInt(int64(3))
  64. privateInputs := []*big.Int{b3}
  65. b35 := big.NewInt(int64(35))
  66. publicInputs := []*big.Int{b35}
  67. // Calculate Witness
  68. w, err := cir.CalculateWitness(privateInputs, publicInputs)
  69. assert.Nil(t, err)
  70. b9 := big.NewInt(int64(9))
  71. b27 := big.NewInt(int64(27))
  72. b30 := big.NewInt(int64(30))
  73. wExpected := []*big.Int{b1, b35, b3, b9, b27, b30, b35, b1}
  74. assert.Equal(t, wExpected, w)
  75. assert.Equal(t, cir.NPublic, 1)
  76. assert.Equal(t, len(cir.PublicInputs), 1)
  77. assert.Equal(t, len(cir.PrivateInputs), 1)
  78. }
  79. func TestCircuitWithFuncCallsParser(t *testing.T) {
  80. // y = x^3 + x + 5
  81. code := `
  82. func exp3(private a):
  83. b = a * a
  84. c = a * b
  85. return c
  86. func sum(private a, private b):
  87. c = a + b
  88. return c
  89. func main(private s0, public s1):
  90. s3 = exp3(s0)
  91. s4 = sum(s3, s0)
  92. s5 = s4 + 5
  93. equals(s1, s5)
  94. out = 1 * 1
  95. `
  96. parser := NewParser(strings.NewReader(code))
  97. cir, err := parser.Parse()
  98. assert.Nil(t, err)
  99. // flat code to R1CS
  100. cir.GenerateR1CS()
  101. assert.Equal(t, "s0", cir.PrivateInputs[0])
  102. assert.Equal(t, "s1", cir.PublicInputs[0])
  103. assert.Equal(t, []string{"one", "s1", "s0", "b0", "s3", "s4", "s5", "out"}, cir.Signals)
  104. // expected result
  105. b0 := big.NewInt(int64(0))
  106. b1 := big.NewInt(int64(1))
  107. b5 := big.NewInt(int64(5))
  108. aExpected := [][]*big.Int{
  109. []*big.Int{b0, b0, b1, b0, b0, b0, b0, b0},
  110. []*big.Int{b0, b0, b1, b0, b0, b0, b0, b0},
  111. []*big.Int{b0, b0, b1, b0, b1, b0, b0, b0},
  112. []*big.Int{b5, b0, b0, b0, b0, b1, b0, b0},
  113. []*big.Int{b0, b0, b0, b0, b0, b0, b1, b0},
  114. []*big.Int{b0, b1, b0, b0, b0, b0, b0, b0},
  115. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  116. }
  117. bExpected := [][]*big.Int{
  118. []*big.Int{b0, b0, b1, b0, b0, b0, b0, b0},
  119. []*big.Int{b0, b0, b0, b1, b0, b0, b0, b0},
  120. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  121. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  122. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  123. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  124. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  125. }
  126. cExpected := [][]*big.Int{
  127. []*big.Int{b0, b0, b0, b1, b0, b0, b0, b0},
  128. []*big.Int{b0, b0, b0, b0, b1, b0, b0, b0},
  129. []*big.Int{b0, b0, b0, b0, b0, b1, b0, b0},
  130. []*big.Int{b0, b0, b0, b0, b0, b0, b1, b0},
  131. []*big.Int{b0, b1, b0, b0, b0, b0, b0, b0},
  132. []*big.Int{b0, b0, b0, b0, b0, b0, b1, b0},
  133. []*big.Int{b0, b0, b0, b0, b0, b0, b0, b1},
  134. }
  135. assert.Equal(t, aExpected, cir.R1CS.A)
  136. assert.Equal(t, bExpected, cir.R1CS.B)
  137. assert.Equal(t, cExpected, cir.R1CS.C)
  138. b3 := big.NewInt(int64(3))
  139. privateInputs := []*big.Int{b3}
  140. b35 := big.NewInt(int64(35))
  141. publicInputs := []*big.Int{b35}
  142. w, err := cir.CalculateWitness(privateInputs, publicInputs)
  143. assert.Nil(t, err)
  144. b9 := big.NewInt(int64(9))
  145. b27 := big.NewInt(int64(27))
  146. b30 := big.NewInt(int64(30))
  147. wExpected := []*big.Int{b1, b35, b3, b9, b27, b30, b35, b1}
  148. assert.Equal(t, wExpected, w)
  149. assert.Equal(t, cir.NPublic, 1)
  150. assert.Equal(t, len(cir.PublicInputs), 1)
  151. assert.Equal(t, len(cir.PrivateInputs), 1)
  152. }
  153. func TestCircuitFromFileWithImports(t *testing.T) {
  154. circuitFile, err := os.Open("./circuit-test-1.circuit")
  155. assert.Nil(t, err)
  156. parser := NewParser(bufio.NewReader(circuitFile))
  157. cir, err := parser.Parse()
  158. assert.Nil(t, err)
  159. // flat code to R1CS
  160. cir.GenerateR1CS()
  161. assert.Equal(t, "s0", cir.PrivateInputs[0])
  162. assert.Equal(t, "s1", cir.PublicInputs[0])
  163. assert.Equal(t, []string{"one", "s1", "s0", "b0", "s3", "s4", "s5", "out"}, cir.Signals)
  164. // expected result
  165. b0 := big.NewInt(int64(0))
  166. b1 := big.NewInt(int64(1))
  167. b5 := big.NewInt(int64(5))
  168. aExpected := [][]*big.Int{
  169. []*big.Int{b0, b0, b1, b0, b0, b0, b0, b0},
  170. []*big.Int{b0, b0, b1, b0, b0, b0, b0, b0},
  171. []*big.Int{b0, b0, b1, b0, b1, b0, b0, b0},
  172. []*big.Int{b5, b0, b0, b0, b0, b1, b0, b0},
  173. []*big.Int{b0, b0, b0, b0, b0, b0, b1, b0},
  174. []*big.Int{b0, b1, b0, b0, b0, b0, b0, b0},
  175. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  176. }
  177. bExpected := [][]*big.Int{
  178. []*big.Int{b0, b0, b1, b0, b0, b0, b0, b0},
  179. []*big.Int{b0, b0, b0, b1, b0, b0, b0, b0},
  180. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  181. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  182. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  183. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  184. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  185. }
  186. cExpected := [][]*big.Int{
  187. []*big.Int{b0, b0, b0, b1, b0, b0, b0, b0},
  188. []*big.Int{b0, b0, b0, b0, b1, b0, b0, b0},
  189. []*big.Int{b0, b0, b0, b0, b0, b1, b0, b0},
  190. []*big.Int{b0, b0, b0, b0, b0, b0, b1, b0},
  191. []*big.Int{b0, b1, b0, b0, b0, b0, b0, b0},
  192. []*big.Int{b0, b0, b0, b0, b0, b0, b1, b0},
  193. []*big.Int{b0, b0, b0, b0, b0, b0, b0, b1},
  194. }
  195. assert.Equal(t, aExpected, cir.R1CS.A)
  196. assert.Equal(t, bExpected, cir.R1CS.B)
  197. assert.Equal(t, cExpected, cir.R1CS.C)
  198. b3 := big.NewInt(int64(3))
  199. privateInputs := []*big.Int{b3}
  200. b35 := big.NewInt(int64(35))
  201. publicInputs := []*big.Int{b35}
  202. // Calculate Witness
  203. w, err := cir.CalculateWitness(privateInputs, publicInputs)
  204. assert.Nil(t, err)
  205. b9 := big.NewInt(int64(9))
  206. b27 := big.NewInt(int64(27))
  207. b30 := big.NewInt(int64(30))
  208. wExpected := []*big.Int{b1, b35, b3, b9, b27, b30, b35, b1}
  209. assert.Equal(t, wExpected, w)
  210. assert.Equal(t, cir.NPublic, 1)
  211. assert.Equal(t, len(cir.PublicInputs), 1)
  212. assert.Equal(t, len(cir.PrivateInputs), 1)
  213. }