You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

158 lines
2.9 KiB

  1. package circuitcompiler
  2. import (
  3. "fmt"
  4. "github.com/stretchr/testify/assert"
  5. "math/big"
  6. "strings"
  7. "testing"
  8. )
  9. type InOut struct {
  10. inputs []*big.Int
  11. result *big.Int
  12. }
  13. type TraceCorrectnessTest struct {
  14. code string
  15. io []InOut
  16. }
  17. var bigNumberResult1, _ = new(big.Int).SetString("2297704271284150716235246193843898764109352875", 10)
  18. var bigNumberResult2, _ = new(big.Int).SetString("75263346540254220740876250", 10)
  19. var correctnesTest = []TraceCorrectnessTest{
  20. {
  21. io: []InOut{{
  22. inputs: []*big.Int{big.NewInt(int64(7)), big.NewInt(int64(11))},
  23. result: big.NewInt(int64(1729500084900343)),
  24. }, {
  25. inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))},
  26. result: bigNumberResult1,
  27. }},
  28. code: `
  29. def do(x):
  30. e = x * 5
  31. b = e * 6
  32. c = b * 7
  33. f = c * 1
  34. d = c * f
  35. out = d * mul(d,e)
  36. def add(x ,k):
  37. z = k * x
  38. out = do(x) + mul(x,z)
  39. def main(x,z):
  40. out = do(z) + add(x,x)
  41. def mul(a,b):
  42. out = a * b
  43. `,
  44. },
  45. {io: []InOut{{
  46. inputs: []*big.Int{big.NewInt(int64(7))},
  47. result: big.NewInt(int64(4)),
  48. }},
  49. code: `
  50. def mul(a,b):
  51. out = a * b
  52. def main(a):
  53. b = a * a
  54. c = 4 - b
  55. d = 5 * c
  56. out = mul(d,c) / mul(b,b)
  57. `,
  58. },
  59. {io: []InOut{{
  60. inputs: []*big.Int{big.NewInt(int64(7)), big.NewInt(int64(11))},
  61. result: big.NewInt(int64(22638)),
  62. }, {
  63. inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))},
  64. result: bigNumberResult2,
  65. }},
  66. code: `
  67. def main(a,b):
  68. d = b + b
  69. c = a * d
  70. e = c - a
  71. out = e * c
  72. `,
  73. },
  74. {
  75. io: []InOut{{
  76. inputs: []*big.Int{big.NewInt(int64(643)), big.NewInt(int64(76548465))},
  77. result: big.NewInt(int64(98441327276)),
  78. }, {
  79. inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))},
  80. result: big.NewInt(int64(8675445947220)),
  81. }},
  82. code: `
  83. def main(a,b):
  84. c = a + b
  85. e = c - a
  86. f = e + b
  87. g = f + 2
  88. out = g * a
  89. `,
  90. },
  91. }
  92. func TestNewProgramm(t *testing.T) {
  93. flat := `
  94. def doSomething(x ,k):
  95. z = k * x
  96. out = 6 + mul(x,z)
  97. def main(x,z):
  98. out = mul(x,z) - doSomething(x,x)
  99. def mul(a,b):
  100. out = a * b
  101. `
  102. for _, test := range correctnesTest {
  103. parser := NewParser(strings.NewReader(test.code))
  104. program, err := parser.Parse()
  105. if err != nil {
  106. panic(err)
  107. }
  108. fmt.Println("\n unreduced")
  109. fmt.Println(test.code)
  110. program.BuildConstraintTrees()
  111. for k, v := range program.functions {
  112. fmt.Println(k)
  113. PrintTree(v.root)
  114. }
  115. fmt.Println("\nReduced gates")
  116. //PrintTree(froots["mul"])
  117. gates := program.ReduceCombinedTree()
  118. for _, g := range gates {
  119. fmt.Printf("\n %v", g)
  120. }
  121. fmt.Println("\n generating R1CS")
  122. a, b, c := program.GenerateReducedR1CS(gates)
  123. fmt.Println(a)
  124. fmt.Println(b)
  125. fmt.Println(c)
  126. for _, io := range test.io {
  127. inputs := io.inputs
  128. fmt.Println("input")
  129. fmt.Println(inputs)
  130. w := program.CalculateWitness(inputs)
  131. fmt.Println("witness")
  132. fmt.Println(w)
  133. assert.Equal(t, io.result, w[len(w)-1])
  134. }
  135. }
  136. }