You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

315 lines
5.5 KiB

  1. package circuitcompiler
  2. import (
  3. "fmt"
  4. "github.com/stretchr/testify/assert"
  5. "math/big"
  6. "strings"
  7. "testing"
  8. )
  9. type InOut struct {
  10. inputs []*big.Int
  11. result *big.Int
  12. }
  13. type TraceCorrectnessTest struct {
  14. code string
  15. io []InOut
  16. }
  17. var bigNumberResult1, _ = new(big.Int).SetString("2297704271284150716235246193843898764109352875", 10)
  18. var bigNumberResult2, _ = new(big.Int).SetString("75263346540254220740876250", 10)
  19. var correctnesTest = []TraceCorrectnessTest{
  20. {
  21. io: []InOut{{
  22. inputs: []*big.Int{big.NewInt(int64(7)), big.NewInt(int64(11))},
  23. result: big.NewInt(int64(1729500084900343)),
  24. }, {
  25. inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))},
  26. result: bigNumberResult1,
  27. }},
  28. code: `
  29. def main( x , z ) :
  30. out = do(z) + add(x,x)
  31. def do(x):
  32. e = x * 5
  33. b = e * 6
  34. c = b * 7
  35. f = c * 1
  36. d = c * f
  37. out = d * mul(d,e)
  38. def add(x ,k):
  39. z = k * x
  40. out = do(x) + mul(x,z)
  41. def mul(a,b):
  42. out = a * b
  43. `,
  44. },
  45. {io: []InOut{{
  46. inputs: []*big.Int{big.NewInt(int64(7))},
  47. result: big.NewInt(int64(4)),
  48. }},
  49. code: `
  50. def mul(a,b):
  51. out = a * b
  52. def main(a):
  53. b = a * a
  54. c = 4 - b
  55. d = 5 * c
  56. out = mul(d,c) / mul(b,b)
  57. `,
  58. },
  59. {io: []InOut{{
  60. inputs: []*big.Int{big.NewInt(int64(7)), big.NewInt(int64(11))},
  61. result: big.NewInt(int64(22638)),
  62. }, {
  63. inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))},
  64. result: bigNumberResult2,
  65. }},
  66. code: `
  67. def main(a,b):
  68. d = b + b
  69. c = a * d
  70. e = c - a
  71. out = e * c
  72. `,
  73. },
  74. {
  75. io: []InOut{{
  76. inputs: []*big.Int{big.NewInt(int64(643)), big.NewInt(int64(76548465))},
  77. result: big.NewInt(int64(98441327276)),
  78. }, {
  79. inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))},
  80. result: big.NewInt(int64(8675445947220)),
  81. }},
  82. code: `
  83. def main(a,b):
  84. c = a + b
  85. e = c - a
  86. f = e + b
  87. g = f + 2
  88. out = g * a
  89. `,
  90. },
  91. {
  92. io: []InOut{{
  93. inputs: []*big.Int{big.NewInt(int64(3)), big.NewInt(int64(5)), big.NewInt(int64(7)), big.NewInt(int64(11))},
  94. result: big.NewInt(int64(444675)),
  95. }},
  96. code: `
  97. def main(a,b,c,d):
  98. e = a * b
  99. f = c * d
  100. g = e * f
  101. h = g / e
  102. i = h * 5
  103. out = g * i
  104. `,
  105. },
  106. {
  107. io: []InOut{{
  108. inputs: []*big.Int{big.NewInt(int64(3)), big.NewInt(int64(5)), big.NewInt(int64(7)), big.NewInt(int64(11))},
  109. result: big.NewInt(int64(264)),
  110. }},
  111. code: `
  112. def main(a,b,c,d):
  113. e = a * 3
  114. f = b * 7
  115. g = c * 11
  116. h = d * 13
  117. i = e + f
  118. j = g + h
  119. k = i + j
  120. out = k * 1
  121. `,
  122. },
  123. }
  124. func TestCorrectness(t *testing.T) {
  125. for _, test := range correctnesTest {
  126. parser := NewParser(strings.NewReader(test.code))
  127. program, err := parser.Parse()
  128. if err != nil {
  129. panic(err)
  130. }
  131. fmt.Println("\n unreduced")
  132. fmt.Println(test.code)
  133. program.BuildConstraintTrees()
  134. for k, v := range program.functions {
  135. fmt.Println(k)
  136. PrintTree(v.root)
  137. }
  138. fmt.Println("\nReduced gates")
  139. //PrintTree(froots["mul"])
  140. gates := program.ReduceCombinedTree()
  141. for _, g := range gates {
  142. fmt.Printf("\n %v", g)
  143. }
  144. fmt.Println("\n generating R1CS")
  145. r1cs := program.GenerateReducedR1CS(gates)
  146. fmt.Println(r1cs.A)
  147. fmt.Println(r1cs.B)
  148. fmt.Println(r1cs.C)
  149. for _, io := range test.io {
  150. inputs := io.inputs
  151. fmt.Println("input")
  152. fmt.Println(inputs)
  153. w := CalculateWitness(inputs, r1cs)
  154. fmt.Println("witness")
  155. fmt.Println(w)
  156. assert.Equal(t, io.result, w[program.GlobalInputCount()])
  157. }
  158. }
  159. }
  160. //test to check gate optimisation
  161. //mess around the code s.t. results is unchanged. number of gates should remain the same in any case
  162. func TestGateOptimisation(t *testing.T) {
  163. io := InOut{
  164. inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))},
  165. result: bigNumberResult1,
  166. }
  167. equalCodes := []string{
  168. `
  169. def main( x , z ) :
  170. out = do(z) + add(x,x)
  171. def do(x):
  172. e = x * 5
  173. b = e * 6
  174. c = 7 * b
  175. f = c * 1
  176. d = f * c
  177. out = d * mul(d,e)
  178. def add(x ,k):
  179. z = k * x
  180. out = do(x) + mul(x,z)
  181. def mul(a,b):
  182. out = b * a
  183. `, //switching order
  184. `
  185. def main( x , z ) :
  186. out = do(z) + add(x,x)
  187. def do(x):
  188. e = x * 5
  189. b = e * 6
  190. c = b * 7
  191. f = c * 1
  192. d = c * f
  193. out = d * mul(d,e)
  194. def add(x ,k):
  195. z = k * x
  196. out = do(x) + mul(x,z)
  197. def mul(a,b):
  198. out = a * b
  199. `, //switching order
  200. `
  201. def main( x , z ) :
  202. out = do(z) + add(x,x)
  203. def do(x):
  204. e = x * 5
  205. j = e * 3
  206. k = e * 3
  207. b = j+k
  208. c = b * 7
  209. f = c * 1
  210. d = c * f
  211. g = d * 1
  212. out = g * mul(d,e)
  213. def add(k ,x):
  214. z = k * x
  215. out = do(x) + mul(x,z)
  216. def mul(b,a):
  217. out = a * b
  218. `, `
  219. def main( x , z ) :
  220. out = add(x,x)+do(z)
  221. def do(x):
  222. e = x * 5
  223. j = 3 * e
  224. k = e * 3
  225. b = j+k
  226. c = b * 7
  227. f = c * 1
  228. d = c * f
  229. g = d * 1
  230. out = mul(d,e) * g
  231. def add(k ,x):
  232. z = k * x
  233. out = mul(x,z) + do(x)
  234. def mul(b,a):
  235. out = a * b
  236. `,
  237. }
  238. var r1css = make([]R1CS, len(equalCodes))
  239. for i, c := range equalCodes {
  240. parser := NewParser(strings.NewReader(c))
  241. program, err := parser.Parse()
  242. if err != nil {
  243. panic(err)
  244. }
  245. program.BuildConstraintTrees()
  246. gates := program.ReduceCombinedTree()
  247. for _, g := range gates {
  248. fmt.Printf("\n %v", g)
  249. }
  250. fmt.Println("\n generating R1CS")
  251. r1cs := program.GenerateReducedR1CS(gates)
  252. r1css[i] = r1cs
  253. fmt.Println(r1cs.A)
  254. fmt.Println(r1cs.B)
  255. fmt.Println(r1cs.C)
  256. }
  257. for i := 0; i < len(equalCodes)-1; i++ {
  258. assert.Equal(t, len(r1css[i].A), len(r1css[i+1].A))
  259. }
  260. for i := 0; i < len(equalCodes); i++ {
  261. //assert.Equal(t, len(r1css[i].A), len(r1css[i+1].A))
  262. w := CalculateWitness(io.inputs, r1css[i])
  263. fmt.Println("witness")
  264. fmt.Println(w)
  265. assert.Equal(t, io.result, w[3])
  266. }
  267. }