You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

166 lines
5.4 KiB

  1. package circuitcompiler
  2. import (
  3. "errors"
  4. "math/big"
  5. "strconv"
  6. "github.com/arnaucube/go-snark/r1csqap"
  7. )
  8. // Circuit is the data structure of the compiled circuit
  9. type Circuit struct {
  10. NVars int
  11. NPublic int
  12. NSignals int
  13. Inputs []string
  14. Signals []string
  15. Witness []*big.Int
  16. Constraints []Constraint
  17. R1CS struct {
  18. A [][]*big.Int
  19. B [][]*big.Int
  20. C [][]*big.Int
  21. }
  22. }
  // Constraint is the data structure of a flat code operation.
//
// Arithmetic constraints read "Out = V1 Op V2". GenerateR1CS and
// CalculateWitness handle the Op values "+", "-", "*", "/" and "in"; an
// operand is either a signal name or a decimal integer literal.
type Constraint struct {
	// v1 op v2 = out
	Op     string
	V1, V2 string
	Out    string
	// Literal is not read by GenerateR1CS or CalculateWitness.
	// NOTE(review): presumably the original source text — confirm.
	Literal string
	// Inputs lists the declared inputs (in func declaration case).
	Inputs []string
}
  // indexInArray returns the position of the first element of arr equal to e,
// or -1 when e does not occur in arr (including when arr is empty or nil).
func indexInArray(arr []string, e string) int {
	for idx := 0; idx < len(arr); idx++ {
		if arr[idx] == e {
			return idx
		}
	}
	return -1
}
  // isValue reports whether a parses as an integer with strconv.Atoi and, if
// so, returns its value. On any parse failure (syntax or range) it returns
// false, 0.
func isValue(a string) (bool, int) {
	if n, err := strconv.Atoi(a); err == nil {
		return true, n
	}
	return false, 0
}
  48. func insertVar(arr []*big.Int, signals []string, v string, used map[string]bool) ([]*big.Int, map[string]bool) {
  49. isVal, value := isValue(v)
  50. valueBigInt := big.NewInt(int64(value))
  51. if isVal {
  52. arr[0] = new(big.Int).Add(arr[0], valueBigInt)
  53. } else {
  54. if !used[v] {
  55. panic(errors.New("using variable before it's set"))
  56. }
  57. arr[indexInArray(signals, v)] = new(big.Int).Add(arr[indexInArray(signals, v)], big.NewInt(int64(1)))
  58. }
  59. return arr, used
  60. }
  61. func insertVarNeg(arr []*big.Int, signals []string, v string, used map[string]bool) ([]*big.Int, map[string]bool) {
  62. isVal, value := isValue(v)
  63. valueBigInt := big.NewInt(int64(value))
  64. if isVal {
  65. arr[0] = new(big.Int).Add(arr[0], valueBigInt)
  66. } else {
  67. if !used[v] {
  68. panic(errors.New("using variable before it's set"))
  69. }
  70. arr[indexInArray(signals, v)] = new(big.Int).Add(arr[indexInArray(signals, v)], big.NewInt(int64(-1)))
  71. }
  72. return arr, used
  73. }
  74. // GenerateR1CS generates the R1CS polynomials from the Circuit
  75. func (circ *Circuit) GenerateR1CS() ([][]*big.Int, [][]*big.Int, [][]*big.Int) {
  76. // from flat code to R1CS
  77. var a [][]*big.Int
  78. var b [][]*big.Int
  79. var c [][]*big.Int
  80. used := make(map[string]bool)
  81. for _, constraint := range circ.Constraints {
  82. aConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals))
  83. bConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals))
  84. cConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals))
  85. // if existInArray(constraint.Out) {
  86. if used[constraint.Out] {
  87. panic(errors.New("out variable already used: " + constraint.Out))
  88. }
  89. used[constraint.Out] = true
  90. if constraint.Op == "in" {
  91. for i := 0; i < len(constraint.Inputs); i++ {
  92. aConstraint[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Add(aConstraint[indexInArray(circ.Signals, constraint.Out)], big.NewInt(int64(1)))
  93. aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.Out, used)
  94. bConstraint[0] = big.NewInt(int64(1))
  95. }
  96. continue
  97. } else if constraint.Op == "+" {
  98. cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1))
  99. aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V1, used)
  100. aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V2, used)
  101. bConstraint[0] = big.NewInt(int64(1))
  102. } else if constraint.Op == "-" {
  103. cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1))
  104. aConstraint, used = insertVarNeg(aConstraint, circ.Signals, constraint.V1, used)
  105. aConstraint, used = insertVarNeg(aConstraint, circ.Signals, constraint.V2, used)
  106. bConstraint[0] = big.NewInt(int64(1))
  107. } else if constraint.Op == "*" {
  108. cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1))
  109. aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V1, used)
  110. bConstraint, used = insertVar(bConstraint, circ.Signals, constraint.V2, used)
  111. } else if constraint.Op == "/" {
  112. cConstraint, used = insertVar(cConstraint, circ.Signals, constraint.V1, used)
  113. cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1))
  114. bConstraint, used = insertVar(bConstraint, circ.Signals, constraint.V2, used)
  115. }
  116. a = append(a, aConstraint)
  117. b = append(b, bConstraint)
  118. c = append(c, cConstraint)
  119. }
  120. return a, b, c
  121. }
  122. func grabVar(signals []string, w []*big.Int, vStr string) *big.Int {
  123. isVal, v := isValue(vStr)
  124. vBig := big.NewInt(int64(v))
  125. if isVal {
  126. return vBig
  127. } else {
  128. return w[indexInArray(signals, vStr)]
  129. }
  130. }
  131. // CalculateWitness calculates the Witness of a Circuit based on the given inputs
  132. func (circ *Circuit) CalculateWitness(inputs []*big.Int) []*big.Int {
  133. w := r1csqap.ArrayOfBigZeros(len(circ.Signals))
  134. w[0] = big.NewInt(int64(1))
  135. for i, input := range inputs {
  136. w[i+1] = input
  137. }
  138. for _, constraint := range circ.Constraints {
  139. if constraint.Op == "in" {
  140. } else if constraint.Op == "+" {
  141. w[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Add(grabVar(circ.Signals, w, constraint.V1), grabVar(circ.Signals, w, constraint.V2))
  142. } else if constraint.Op == "-" {
  143. w[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Sub(grabVar(circ.Signals, w, constraint.V1), grabVar(circ.Signals, w, constraint.V2))
  144. } else if constraint.Op == "*" {
  145. w[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Mul(grabVar(circ.Signals, w, constraint.V1), grabVar(circ.Signals, w, constraint.V2))
  146. } else if constraint.Op == "/" {
  147. w[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Div(grabVar(circ.Signals, w, constraint.V1), grabVar(circ.Signals, w, constraint.V2))
  148. }
  149. }
  150. return w
  151. }