You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

162 lines
3.0 KiB

  1. package circuitcompiler
  2. import (
  3. "fmt"
  4. "github.com/stretchr/testify/assert"
  5. "math/big"
  6. "strings"
  7. "testing"
  8. )
  9. type InOut struct {
  10. inputs []*big.Int
  11. result *big.Int
  12. }
  13. type TraceCorrectnessTest struct {
  14. code string
  15. io []InOut
  16. }
  17. var bigNumberResult1, _ = new(big.Int).SetString("2297704271284150716235246193843898764109352875", 10)
  18. var bigNumberResult2, _ = new(big.Int).SetString("75263346540254220740876250", 10)
// correctnesTest lists circuit programs (in the compiler's source language)
// together with input sets and the value TestNewProgramm expects to read
// from the computed witness for each input set.
var correctnesTest = []TraceCorrectnessTest{
	// Case 1: nested function calls. With do(x) = d*d*e = 9724050000*x^5 and
	// add(x,x) = do(x) + x^3, main reduces to 9724050000*(z^5 + x^5) + x^3.
	{
		io: []InOut{{
			inputs: []*big.Int{big.NewInt(int64(7)), big.NewInt(int64(11))},
			result: big.NewInt(int64(1729500084900343)),
		}, {
			inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))},
			result: bigNumberResult1,
		}},
		code: `
def main( x , z ) :
out = do(z) + add(x,x)
def do(x):
e = x * 5
b = e * 6
c = b * 7
f = c * 1
d = c * f
out = d * mul(d,e)
def add(x ,k):
z = k * x
out = do(x) + mul(x,z)
def mul(a,b):
out = a * b
`,
	},
	// Case 2: subtraction and division across function calls:
	// out = (5c * c) / (b * b) with b = a^2 and c = 4 - b.
	// NOTE(review): for a = 7 this is 10125 / 2401, which is not an integer;
	// the expected 4 suggests truncating integer division — confirm against
	// how CalculateWitness evaluates division.
	{io: []InOut{{
		inputs: []*big.Int{big.NewInt(int64(7))},
		result: big.NewInt(int64(4)),
	}},
		code: `
def mul(a,b):
out = a * b
def main(a):
b = a * a
c = 4 - b
d = 5 * c
out = mul(d,c) / mul(b,b)
`,
	},
	// Case 3: out = (2ab - a) * 2ab.
	{io: []InOut{{
		inputs: []*big.Int{big.NewInt(int64(7)), big.NewInt(int64(11))},
		result: big.NewInt(int64(22638)),
	}, {
		inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))},
		result: bigNumberResult2,
	}},
		code: `
def main(a,b):
d = b + b
c = a * d
e = c - a
out = e * c
`,
	},
	// Case 4: additions/subtractions with a constant; c - a cancels back to b,
	// so out = a * (2b + 2).
	{
		io: []InOut{{
			inputs: []*big.Int{big.NewInt(int64(643)), big.NewInt(int64(76548465))},
			result: big.NewInt(int64(98441327276)),
		}, {
			inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))},
			result: big.NewInt(int64(8675445947220)),
		}},
		code: `
def main(a,b):
c = a + b
e = c - a
f = e + b
g = f + 2
out = g * a
`,
	},
	// Case 5: four inputs with an exact division; h = g / e recovers c*d,
	// so out = abcd * 5cd.
	{
		io: []InOut{{
			inputs: []*big.Int{big.NewInt(int64(3)), big.NewInt(int64(5)), big.NewInt(int64(7)), big.NewInt(int64(11))},
			result: big.NewInt(int64(444675)),
		}},
		code: `
def main(a,b,c,d):
e = a * b
f = c * d
g = e * f
h = g / e
i = h * 5
out = g * i
`,
	},
}
  107. func TestNewProgramm(t *testing.T) {
  108. for _, test := range correctnesTest {
  109. parser := NewParser(strings.NewReader(test.code))
  110. program, err := parser.Parse()
  111. if err != nil {
  112. panic(err)
  113. }
  114. fmt.Println("\n unreduced")
  115. fmt.Println(test.code)
  116. program.BuildConstraintTrees()
  117. for k, v := range program.functions {
  118. fmt.Println(k)
  119. PrintTree(v.root)
  120. }
  121. fmt.Println("\nReduced gates")
  122. //PrintTree(froots["mul"])
  123. gates := program.ReduceCombinedTree()
  124. for _, g := range gates {
  125. fmt.Printf("\n %v", g)
  126. }
  127. fmt.Println("\n generating R1CS")
  128. r1cs := program.GenerateReducedR1CS(gates)
  129. fmt.Println(r1cs.A)
  130. fmt.Println(r1cs.B)
  131. fmt.Println(r1cs.C)
  132. for _, io := range test.io {
  133. inputs := io.inputs
  134. fmt.Println("input")
  135. fmt.Println(inputs)
  136. w := CalculateWitness(inputs, r1cs)
  137. fmt.Println("witness")
  138. fmt.Println(w)
  139. assert.Equal(t, io.result, w[len(program.globalInputs)-1])
  140. }
  141. }
  142. }