You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

158 lines
2.9 KiB

  1. package circuitcompiler
  2. import (
  3. "fmt"
  4. "github.com/stretchr/testify/assert"
  5. "math/big"
  6. "strings"
  7. "testing"
  8. )
  // Table types for the trace-correctness test.
type (
	// InOut pairs one input vector with the value the compiled program is
	// expected to produce for it.
	InOut struct {
		inputs []*big.Int // values handed to CalculateWitness (presumably main's arguments)
		result *big.Int   // expected value of the final witness element
	}

	// TraceCorrectnessTest couples a program's source text with the
	// input/output pairs it is checked against.
	TraceCorrectnessTest struct {
		code string  // program source handed to NewParser
		io   []InOut // input vectors and their expected results
	}
)
  // mustParseBigInt parses s as a base-10 integer. It panics if s is not a
// valid literal, so a typo in a test constant fails loudly at package
// initialisation instead of silently producing a nil *big.Int.
func mustParseBigInt(s string) *big.Int {
	n, ok := new(big.Int).SetString(s, 10)
	if !ok {
		panic(fmt.Sprintf("invalid base-10 integer literal %q", s))
	}
	return n
}

// Expected results that do not fit in an int64 and are referenced by
// correctnesTest.
var (
	bigNumberResult1 = mustParseBigInt("2297704271284150716235246193843898764109352875")
	bigNumberResult2 = mustParseBigInt("75263346540254220740876250")
)
  // correctnesTest is the table driven by TestNewProgramm. Each entry holds a
  // program's source (code) and one or more input vectors together with the
  // value the last witness element is expected to equal.
  // NOTE(review): inputs appear to be bound positionally to main's parameters;
  // the hand-checked expected values below are consistent with that.
  // The identifier's spelling ("correctnes") is kept so existing references still compile.
  19. var correctnesTest = []TraceCorrectnessTest{
  // main(x,z) = do(z) + add(x,x), where do(y) = 9724050000*y^5 and
  // add(x,x) = do(x) + x^3; e.g. x=7, z=11 gives 1729500084900343.
  20. {
  21. io: []InOut{{
  22. inputs: []*big.Int{big.NewInt(int64(7)), big.NewInt(int64(11))},
  23. result: big.NewInt(int64(1729500084900343)),
  24. }, {
  25. inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))},
  26. result: bigNumberResult1,
  27. }},
  28. code: `
  29. def do(x):
  30. e = x * 5
  31. b = e * 6
  32. c = b * 7
  33. f = c * 1
  34. d = c * f
  35. out = d * mul(d,e)
  36. def add(x ,k):
  37. z = k * x
  38. out = do(x) + mul(x,z)
  39. def main(x,z):
  40. out = do(z) + add(x,x)
  41. def mul(a,b):
  42. out = a * b
  43. `,
  44. },
  // main(a) = 5*c*c / (b*b) with b = a*a and c = 4 - b. For a=7 this is
  // 10125/2401, so the expected 4 presumably relies on truncating integer
  // division in the witness computation. TODO confirm.
  45. {io: []InOut{{
  46. inputs: []*big.Int{big.NewInt(int64(7))},
  47. result: big.NewInt(int64(4)),
  48. }},
  49. code: `
  50. def mul(a,b):
  51. out = a * b
  52. def main(a):
  53. b = a * a
  54. c = 4 - b
  55. d = 5 * c
  56. out = mul(d,c) / mul(b,b)
  57. `,
  58. },
  // main(a,b) = (c - a) * c with c = a*(b+b); e.g. (7, 11) gives 147*154 = 22638.
  59. {io: []InOut{{
  60. inputs: []*big.Int{big.NewInt(int64(7)), big.NewInt(int64(11))},
  61. result: big.NewInt(int64(22638)),
  62. }, {
  63. inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))},
  64. result: bigNumberResult2,
  65. }},
  66. code: `
  67. def main(a,b):
  68. d = b + b
  69. c = a * d
  70. e = c - a
  71. out = e * c
  72. `,
  73. },
  // main(a,b) = (2b + 2) * a, since e = (a + b) - a = b;
  // e.g. (643, 76548465) gives 153096932*643 = 98441327276.
  74. {
  75. io: []InOut{{
  76. inputs: []*big.Int{big.NewInt(int64(643)), big.NewInt(int64(76548465))},
  77. result: big.NewInt(int64(98441327276)),
  78. }, {
  79. inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))},
  80. result: big.NewInt(int64(8675445947220)),
  81. }},
  82. code: `
  83. def main(a,b):
  84. c = a + b
  85. e = c - a
  86. f = e + b
  87. g = f + 2
  88. out = g * a
  89. `,
  90. },
  91. }
  92. func TestNewProgramm(t *testing.T) {
  93. //flat := `
  94. //
  95. //def doSomething(x ,k):
  96. // z = k * x
  97. // out = 6 + mul(x,z)
  98. //
  99. //def main(x,z):
  100. // out = mul(x,z) - doSomething(x,x)
  101. //
  102. //def mul(a,b):
  103. // out = a * b
  104. //`
  105. for _, test := range correctnesTest {
  106. parser := NewParser(strings.NewReader(test.code))
  107. program, err := parser.Parse()
  108. if err != nil {
  109. panic(err)
  110. }
  111. fmt.Println("\n unreduced")
  112. fmt.Println(test.code)
  113. program.BuildConstraintTrees()
  114. for k, v := range program.functions {
  115. fmt.Println(k)
  116. PrintTree(v.root)
  117. }
  118. fmt.Println("\nReduced gates")
  119. //PrintTree(froots["mul"])
  120. gates := program.ReduceCombinedTree()
  121. for _, g := range gates {
  122. fmt.Printf("\n %v", g)
  123. }
  124. fmt.Println("\n generating R1CS")
  125. r1cs := program.GenerateReducedR1CS(gates)
  126. fmt.Println(r1cs.A)
  127. fmt.Println(r1cs.B)
  128. fmt.Println(r1cs.C)
  129. for _, io := range test.io {
  130. inputs := io.inputs
  131. fmt.Println("input")
  132. fmt.Println(inputs)
  133. w := CalculateWitness(inputs, r1cs)
  134. fmt.Println("witness")
  135. fmt.Println(w)
  136. assert.Equal(t, io.result, w[len(w)-1])
  137. }
  138. }
  139. }