You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

205 lines
5.1 KiB

  1. package circuitcompiler
  2. import (
  3. "fmt"
  4. "github.com/stretchr/testify/assert"
  5. "math/big"
  6. "math/rand"
  7. "strings"
  8. "testing"
  9. )
  10. //factors are essential to identify, if a specific gate has been computed already
  11. //eg. if we can extract a factor from a gate that is independent of commutativity, multiplicativitz we will do much better, in finding and reusing old outputs do
  12. //minimize the multiplication gate number
  13. // for example the gate a*b == gate b*a hence, we only need to compute one of both.
  14. func TestFactorSignature(t *testing.T) {
  15. facNeutral := factors{&factor{multiplicative: [2]int{1, 1}}}
  16. //dont let the random number be to big, cuz of overflow
  17. r1, r2 := rand.Intn(1<<16), rand.Intn(1<<16)
  18. fmt.Println(r1, r2)
  19. equalityGroups := [][]factors{
  20. []factors{ //test sign and gcd
  21. {&factor{multiplicative: [2]int{r1 * 2, -r2 * 2}}},
  22. {&factor{multiplicative: [2]int{-r1, r2}}},
  23. {&factor{multiplicative: [2]int{r1, -r2}}},
  24. {&factor{multiplicative: [2]int{r1 * 3, -r2 * 3}}},
  25. {&factor{multiplicative: [2]int{r1 * r1, -r2 * r1}}},
  26. {&factor{multiplicative: [2]int{r1 * r2, -r2 * r2}}},
  27. }, []factors{ //test kommutativity
  28. {&factor{multiplicative: [2]int{r1, -r2}}, &factor{multiplicative: [2]int{13, 27}}},
  29. {&factor{multiplicative: [2]int{13, 27}}, &factor{multiplicative: [2]int{-r1, r2}}},
  30. },
  31. }
  32. for _, equalityGroup := range equalityGroups {
  33. for i := 0; i < len(equalityGroup)-1; i++ {
  34. sig1, _, _ := factorsSignature(facNeutral, equalityGroup[i])
  35. sig2, _, _ := factorsSignature(facNeutral, equalityGroup[i+1])
  36. assert.Equal(t, sig1, sig2)
  37. sig1, _, _ = factorsSignature(equalityGroup[i], facNeutral)
  38. sig2, _, _ = factorsSignature(facNeutral, equalityGroup[i+1])
  39. assert.Equal(t, sig1, sig2)
  40. sig1, _, _ = factorsSignature(facNeutral, equalityGroup[i])
  41. sig2, _, _ = factorsSignature(equalityGroup[i+1], facNeutral)
  42. assert.Equal(t, sig1, sig2)
  43. sig1, _, _ = factorsSignature(equalityGroup[i], facNeutral)
  44. sig2, _, _ = factorsSignature(equalityGroup[i+1], facNeutral)
  45. assert.Equal(t, sig1, sig2)
  46. }
  47. }
  48. }
  49. func TestGate_ExtractValues(t *testing.T) {
  50. facNeutral := factors{&factor{multiplicative: [2]int{8, 7}}, &factor{multiplicative: [2]int{9, 3}}}
  51. facNeutral2 := factors{&factor{multiplicative: [2]int{9, 1}}, &factor{multiplicative: [2]int{13, 7}}}
  52. fmt.Println(factorsSignature(facNeutral, facNeutral2))
  53. f, fc := extractFactor(facNeutral)
  54. fmt.Println(f)
  55. fmt.Println(fc)
  56. f2, _ := extractFactor(facNeutral2)
  57. fmt.Println(f)
  58. fmt.Println(fc)
  59. fmt.Println(factorsSignature(facNeutral, facNeutral2))
  60. fmt.Println(factorsSignature(f, f2))
  61. }
  62. func TestGCD(t *testing.T) {
  63. fmt.Println(LCM(10, 15))
  64. fmt.Println(LCM(10, 15, 20))
  65. fmt.Println(LCM(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
  66. }
  // correctnesTest2 holds the end-to-end cases run by TestCorrectness2.
  // Each entry has a circuit program in `code` and one or more input sets
  // in `io`. For each input set, `result` is compared with the witness entry
  // at index program.GlobalInputCount().
  67. var correctnesTest2 = []TraceCorrectnessTest{
  // Case 1: e = c - a = b, so out = (2*b + 2) * a.
  68. {
  69. io: []InOut{{
  70. inputs: []*big.Int{big.NewInt(int64(643)), big.NewInt(int64(76548465))},
  71. result: big.NewInt(int64(98441327276)),
  72. }, {
  73. inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))},
  74. result: big.NewInt(int64(8675445947220)),
  75. }},
  76. code: `
  77. def main(a,b):
  78. c = a + b
  79. e = c - a
  80. f = e + b
  81. g = f + 2
  82. out = g * a
  83. `,
  84. },
  // Case 2: a user-defined function (mul) called from main, plus division.
  // For a = 7: b = 49, c = -45, d = -225, mul(d,c) = 10125, mul(b,b) = 2401.
  // NOTE(review): 10125/2401 is not an integer; the expected 4 matches
  // truncating integer division. Confirm how CalculateWitness evaluates `/`.
  85. {io: []InOut{{
  86. inputs: []*big.Int{big.NewInt(int64(7))},
  87. result: big.NewInt(int64(4)),
  88. }},
  89. code: `
  90. def mul(a,b):
  91. out = a * b
  92. def main(a):
  93. b = a * a
  94. c = 4 - b
  95. d = 5 * c
  96. out = mul(d,c) / mul(b,b)
  97. `,
  98. },
  // Case 3: c = 2*a*b, out = (c - a) * c. The expected value for the second
  // input set, bigNumberResult2, is declared outside this block.
  99. {io: []InOut{{
  100. inputs: []*big.Int{big.NewInt(int64(7)), big.NewInt(int64(11))},
  101. result: big.NewInt(int64(22638)),
  102. }, {
  103. inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))},
  104. result: bigNumberResult2,
  105. }},
  106. code: `
  107. def main(a,b):
  108. d = b + b
  109. c = a * d
  110. e = c - a
  111. out = e * c
  112. `,
  113. },
  // Case 4: h = g / e = c*d, so out = (a*b*c*d) * (5*c*d) = 1155 * 385.
  114. {
  115. io: []InOut{{
  116. inputs: []*big.Int{big.NewInt(int64(3)), big.NewInt(int64(5)), big.NewInt(int64(7)), big.NewInt(int64(11))},
  117. result: big.NewInt(int64(444675)),
  118. }},
  119. code: `
  120. def main(a,b,c,d):
  121. e = a * b
  122. f = c * d
  123. g = e * f
  124. h = g / e
  125. i = h * 5
  126. out = g * i
  127. `,
  128. },
  // Case 5: a purely linear combination, out = 3a + 7b + 11c + 13d
  // (9 + 35 + 77 + 143 = 264). The trailing `k * 1` is presumably there so
  // that the output comes from a multiplication gate; confirm.
  129. {
  130. io: []InOut{{
  131. inputs: []*big.Int{big.NewInt(int64(3)), big.NewInt(int64(5)), big.NewInt(int64(7)), big.NewInt(int64(11))},
  132. result: big.NewInt(int64(264)),
  133. }},
  134. code: `
  135. def main(a,b,c,d):
  136. e = a * 3
  137. f = b * 7
  138. g = c * 11
  139. h = d * 13
  140. i = e + f
  141. j = g + h
  142. k = i + j
  143. out = k * 1
  144. `,
  145. },
  146. }
  147. func TestCorrectness2(t *testing.T) {
  148. for _, test := range correctnesTest2 {
  149. parser := NewParser(strings.NewReader(test.code))
  150. program, err := parser.Parse()
  151. if err != nil {
  152. panic(err)
  153. }
  154. fmt.Println("\n unreduced")
  155. fmt.Println(test.code)
  156. program.BuildConstraintTrees()
  157. for k, v := range program.functions {
  158. fmt.Println(k)
  159. PrintTree(v.root)
  160. }
  161. fmt.Println("\nReduced gates")
  162. //PrintTree(froots["mul"])
  163. gates := program.ReduceCombinedTree()
  164. for _, g := range gates {
  165. fmt.Printf("\n %v", g)
  166. }
  167. fmt.Println("\n generating R1CS")
  168. r1cs := program.GenerateReducedR1CS(gates)
  169. fmt.Println(r1cs.A)
  170. fmt.Println(r1cs.B)
  171. fmt.Println(r1cs.C)
  172. for _, io := range test.io {
  173. inputs := io.inputs
  174. fmt.Println("input")
  175. fmt.Println(inputs)
  176. w := CalculateWitness(inputs, r1cs)
  177. fmt.Println("witness")
  178. fmt.Println(w)
  179. assert.Equal(t, io.result, w[program.GlobalInputCount()])
  180. }
  181. }
  182. }