You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

248 lines
7.4 KiB

5 years ago
  1. package circuitcompiler
  2. import (
  3. "bufio"
  4. "math/big"
  5. "os"
  6. "strings"
  7. "testing"
  8. "github.com/stretchr/testify/assert"
  9. )
  10. func TestCircuitParser(t *testing.T) {
  11. // y = x^3 + x + 5
  12. flat := `
  13. func main(private s0, public s1):
  14. s2 = s0 * s0
  15. s3 = s2 * s0
  16. s4 = s3 + s0
  17. s5 = s4 + 5
  18. equals(s1, s5)
  19. out = 1 * 1
  20. `
  21. parser := NewParser(strings.NewReader(flat))
  22. circuit, err := parser.Parse()
  23. assert.Nil(t, err)
  24. // flat code to R1CS
  25. a, b, c := circuit.GenerateR1CS()
  26. assert.Equal(t, "s0", circuit.PrivateInputs[0])
  27. assert.Equal(t, "s1", circuit.PublicInputs[0])
  28. assert.Equal(t, []string{"one", "s1", "s0", "s2", "s3", "s4", "s5", "out"}, circuit.Signals)
  29. // expected result
  30. b0 := big.NewInt(int64(0))
  31. b1 := big.NewInt(int64(1))
  32. b5 := big.NewInt(int64(5))
  33. aExpected := [][]*big.Int{
  34. []*big.Int{b0, b0, b1, b0, b0, b0, b0, b0},
  35. []*big.Int{b0, b0, b0, b1, b0, b0, b0, b0},
  36. []*big.Int{b0, b0, b1, b0, b1, b0, b0, b0},
  37. []*big.Int{b5, b0, b0, b0, b0, b1, b0, b0},
  38. []*big.Int{b0, b0, b0, b0, b0, b0, b1, b0},
  39. []*big.Int{b0, b1, b0, b0, b0, b0, b0, b0},
  40. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  41. }
  42. bExpected := [][]*big.Int{
  43. []*big.Int{b0, b0, b1, b0, b0, b0, b0, b0},
  44. []*big.Int{b0, b0, b1, b0, b0, b0, b0, b0},
  45. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  46. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  47. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  48. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  49. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  50. }
  51. cExpected := [][]*big.Int{
  52. []*big.Int{b0, b0, b0, b1, b0, b0, b0, b0},
  53. []*big.Int{b0, b0, b0, b0, b1, b0, b0, b0},
  54. []*big.Int{b0, b0, b0, b0, b0, b1, b0, b0},
  55. []*big.Int{b0, b0, b0, b0, b0, b0, b1, b0},
  56. []*big.Int{b0, b1, b0, b0, b0, b0, b0, b0},
  57. []*big.Int{b0, b0, b0, b0, b0, b0, b1, b0},
  58. []*big.Int{b0, b0, b0, b0, b0, b0, b0, b1},
  59. }
  60. assert.Equal(t, aExpected, a)
  61. assert.Equal(t, bExpected, b)
  62. assert.Equal(t, cExpected, c)
  63. b3 := big.NewInt(int64(3))
  64. privateInputs := []*big.Int{b3}
  65. b35 := big.NewInt(int64(35))
  66. publicInputs := []*big.Int{b35}
  67. // Calculate Witness
  68. w, err := circuit.CalculateWitness(privateInputs, publicInputs)
  69. assert.Nil(t, err)
  70. b9 := big.NewInt(int64(9))
  71. b27 := big.NewInt(int64(27))
  72. b30 := big.NewInt(int64(30))
  73. wExpected := []*big.Int{b1, b35, b3, b9, b27, b30, b35, b1}
  74. assert.Equal(t, wExpected, w)
  75. // circuitJson, _ := json.Marshal(circuit)
  76. // fmt.Println("circuit:", string(circuitJson))
  77. assert.Equal(t, circuit.NPublic, 1)
  78. assert.Equal(t, len(circuit.PublicInputs), 1)
  79. assert.Equal(t, len(circuit.PrivateInputs), 1)
  80. }
  81. func TestCircuitWithFuncCallsParser(t *testing.T) {
  82. // y = x^3 + x + 5
  83. code := `
  84. func exp3(private a):
  85. b = a * a
  86. c = a * b
  87. return c
  88. func sum(private a, private b):
  89. c = a + b
  90. return c
  91. func main(private s0, public s1):
  92. s3 = exp3(s0)
  93. s4 = sum(s3, s0)
  94. s5 = s4 + 5
  95. equals(s1, s5)
  96. out = 1 * 1
  97. `
  98. parser := NewParser(strings.NewReader(code))
  99. circuit, err := parser.Parse()
  100. assert.Nil(t, err)
  101. // flat code to R1CS
  102. a, b, c := circuit.GenerateR1CS()
  103. assert.Equal(t, "s0", circuit.PrivateInputs[0])
  104. assert.Equal(t, "s1", circuit.PublicInputs[0])
  105. assert.Equal(t, []string{"one", "s1", "s0", "b0", "s3", "s4", "s5", "out"}, circuit.Signals)
  106. // expected result
  107. b0 := big.NewInt(int64(0))
  108. b1 := big.NewInt(int64(1))
  109. b5 := big.NewInt(int64(5))
  110. aExpected := [][]*big.Int{
  111. []*big.Int{b0, b0, b1, b0, b0, b0, b0, b0},
  112. []*big.Int{b0, b0, b1, b0, b0, b0, b0, b0},
  113. []*big.Int{b0, b0, b1, b0, b1, b0, b0, b0},
  114. []*big.Int{b5, b0, b0, b0, b0, b1, b0, b0},
  115. []*big.Int{b0, b0, b0, b0, b0, b0, b1, b0},
  116. []*big.Int{b0, b1, b0, b0, b0, b0, b0, b0},
  117. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  118. }
  119. bExpected := [][]*big.Int{
  120. []*big.Int{b0, b0, b1, b0, b0, b0, b0, b0},
  121. []*big.Int{b0, b0, b0, b1, b0, b0, b0, b0},
  122. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  123. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  124. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  125. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  126. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  127. }
  128. cExpected := [][]*big.Int{
  129. []*big.Int{b0, b0, b0, b1, b0, b0, b0, b0},
  130. []*big.Int{b0, b0, b0, b0, b1, b0, b0, b0},
  131. []*big.Int{b0, b0, b0, b0, b0, b1, b0, b0},
  132. []*big.Int{b0, b0, b0, b0, b0, b0, b1, b0},
  133. []*big.Int{b0, b1, b0, b0, b0, b0, b0, b0},
  134. []*big.Int{b0, b0, b0, b0, b0, b0, b1, b0},
  135. []*big.Int{b0, b0, b0, b0, b0, b0, b0, b1},
  136. }
  137. assert.Equal(t, aExpected, a)
  138. assert.Equal(t, bExpected, b)
  139. assert.Equal(t, cExpected, c)
  140. b3 := big.NewInt(int64(3))
  141. privateInputs := []*big.Int{b3}
  142. b35 := big.NewInt(int64(35))
  143. publicInputs := []*big.Int{b35}
  144. // Calculate Witness
  145. w, err := circuit.CalculateWitness(privateInputs, publicInputs)
  146. assert.Nil(t, err)
  147. b9 := big.NewInt(int64(9))
  148. b27 := big.NewInt(int64(27))
  149. b30 := big.NewInt(int64(30))
  150. wExpected := []*big.Int{b1, b35, b3, b9, b27, b30, b35, b1}
  151. assert.Equal(t, wExpected, w)
  152. // circuitJson, _ := json.Marshal(circuit)
  153. // fmt.Println("circuit:", string(circuitJson))
  154. assert.Equal(t, circuit.NPublic, 1)
  155. assert.Equal(t, len(circuit.PublicInputs), 1)
  156. assert.Equal(t, len(circuit.PrivateInputs), 1)
  157. }
  158. func TestCircuitFromFileWithImports(t *testing.T) {
  159. circuitFile, err := os.Open("./circuit-test-1.circuit")
  160. assert.Nil(t, err)
  161. parser := NewParser(bufio.NewReader(circuitFile))
  162. circuit, err := parser.Parse()
  163. assert.Nil(t, err)
  164. // flat code to R1CS
  165. a, b, c := circuit.GenerateR1CS()
  166. assert.Equal(t, "s0", circuit.PrivateInputs[0])
  167. assert.Equal(t, "s1", circuit.PublicInputs[0])
  168. assert.Equal(t, []string{"one", "s1", "s0", "b0", "s3", "s4", "s5", "out"}, circuit.Signals)
  169. // expected result
  170. b0 := big.NewInt(int64(0))
  171. b1 := big.NewInt(int64(1))
  172. b5 := big.NewInt(int64(5))
  173. aExpected := [][]*big.Int{
  174. []*big.Int{b0, b0, b1, b0, b0, b0, b0, b0},
  175. []*big.Int{b0, b0, b1, b0, b0, b0, b0, b0},
  176. []*big.Int{b0, b0, b1, b0, b1, b0, b0, b0},
  177. []*big.Int{b5, b0, b0, b0, b0, b1, b0, b0},
  178. []*big.Int{b0, b0, b0, b0, b0, b0, b1, b0},
  179. []*big.Int{b0, b1, b0, b0, b0, b0, b0, b0},
  180. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  181. }
  182. bExpected := [][]*big.Int{
  183. []*big.Int{b0, b0, b1, b0, b0, b0, b0, b0},
  184. []*big.Int{b0, b0, b0, b1, b0, b0, b0, b0},
  185. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  186. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  187. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  188. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  189. []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
  190. }
  191. cExpected := [][]*big.Int{
  192. []*big.Int{b0, b0, b0, b1, b0, b0, b0, b0},
  193. []*big.Int{b0, b0, b0, b0, b1, b0, b0, b0},
  194. []*big.Int{b0, b0, b0, b0, b0, b1, b0, b0},
  195. []*big.Int{b0, b0, b0, b0, b0, b0, b1, b0},
  196. []*big.Int{b0, b1, b0, b0, b0, b0, b0, b0},
  197. []*big.Int{b0, b0, b0, b0, b0, b0, b1, b0},
  198. []*big.Int{b0, b0, b0, b0, b0, b0, b0, b1},
  199. }
  200. assert.Equal(t, aExpected, a)
  201. assert.Equal(t, bExpected, b)
  202. assert.Equal(t, cExpected, c)
  203. b3 := big.NewInt(int64(3))
  204. privateInputs := []*big.Int{b3}
  205. b35 := big.NewInt(int64(35))
  206. publicInputs := []*big.Int{b35}
  207. // Calculate Witness
  208. w, err := circuit.CalculateWitness(privateInputs, publicInputs)
  209. assert.Nil(t, err)
  210. b9 := big.NewInt(int64(9))
  211. b27 := big.NewInt(int64(27))
  212. b30 := big.NewInt(int64(30))
  213. wExpected := []*big.Int{b1, b35, b3, b9, b27, b30, b35, b1}
  214. assert.Equal(t, wExpected, w)
  215. // circuitJson, _ := json.Marshal(circuit)
  216. // fmt.Println("circuit:", string(circuitJson))
  217. assert.Equal(t, circuit.NPublic, 1)
  218. assert.Equal(t, len(circuit.PublicInputs), 1)
  219. assert.Equal(t, len(circuit.PrivateInputs), 1)
  220. }