You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

179 lines
5.7 KiB

5 years ago
5 years ago
  1. package circuitcompiler
  2. import (
  3. "errors"
  4. "math/big"
  5. "strconv"
  6. "github.com/arnaucube/go-snark/r1csqap"
  7. )
  8. // Circuit is the data structure of the compiled circuit
  9. type Circuit struct {
  10. NVars int
  11. NPublic int
  12. NSignals int
  13. Inputs []string
  14. Signals []string
  15. PublicSignals []string
  16. Witness []*big.Int
  17. Constraints []Constraint
  18. R1CS struct {
  19. A [][]*big.Int
  20. B [][]*big.Int
  21. C [][]*big.Int
  22. }
  23. }
  24. // Constraint is the data structure of a flat code operation
  25. type Constraint struct {
  26. // v1 op v2 = out
  27. Op string
  28. V1 string
  29. V2 string
  30. Out string
  31. Literal string
  32. Inputs []string // in func declaration case
  33. }
  34. func indexInArray(arr []string, e string) int {
  35. for i, a := range arr {
  36. if a == e {
  37. return i
  38. }
  39. }
  40. return -1
  41. }
  42. func isValue(a string) (bool, int) {
  43. v, err := strconv.Atoi(a)
  44. if err != nil {
  45. return false, 0
  46. }
  47. return true, v
  48. }
  49. func insertVar(arr []*big.Int, signals []string, v string, used map[string]bool) ([]*big.Int, map[string]bool) {
  50. isVal, value := isValue(v)
  51. valueBigInt := big.NewInt(int64(value))
  52. if isVal {
  53. arr[0] = new(big.Int).Add(arr[0], valueBigInt)
  54. } else {
  55. if !used[v] {
  56. panic(errors.New("using variable before it's set"))
  57. }
  58. arr[indexInArray(signals, v)] = new(big.Int).Add(arr[indexInArray(signals, v)], big.NewInt(int64(1)))
  59. }
  60. return arr, used
  61. }
  62. func insertVarNeg(arr []*big.Int, signals []string, v string, used map[string]bool) ([]*big.Int, map[string]bool) {
  63. isVal, value := isValue(v)
  64. valueBigInt := big.NewInt(int64(value))
  65. if isVal {
  66. arr[0] = new(big.Int).Add(arr[0], valueBigInt)
  67. } else {
  68. if !used[v] {
  69. panic(errors.New("using variable before it's set"))
  70. }
  71. arr[indexInArray(signals, v)] = new(big.Int).Add(arr[indexInArray(signals, v)], big.NewInt(int64(-1)))
  72. }
  73. return arr, used
  74. }
  75. // GenerateR1CS generates the R1CS polynomials from the Circuit
  76. func (circ *Circuit) GenerateR1CS() ([][]*big.Int, [][]*big.Int, [][]*big.Int) {
  77. // from flat code to R1CS
  78. var a [][]*big.Int
  79. var b [][]*big.Int
  80. var c [][]*big.Int
  81. used := make(map[string]bool)
  82. for _, constraint := range circ.Constraints {
  83. aConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals))
  84. bConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals))
  85. cConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals))
  86. // if existInArray(constraint.Out) {
  87. if used[constraint.Out] {
  88. panic(errors.New("out variable already used: " + constraint.Out))
  89. }
  90. used[constraint.Out] = true
  91. if constraint.Op == "in" {
  92. for i := 0; i < len(constraint.Inputs); i++ {
  93. aConstraint[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Add(aConstraint[indexInArray(circ.Signals, constraint.Out)], big.NewInt(int64(1)))
  94. aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.Out, used)
  95. bConstraint[0] = big.NewInt(int64(1))
  96. }
  97. continue
  98. } else if constraint.Op == "+" {
  99. cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1))
  100. aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V1, used)
  101. aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V2, used)
  102. bConstraint[0] = big.NewInt(int64(1))
  103. } else if constraint.Op == "-" {
  104. cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1))
  105. aConstraint, used = insertVarNeg(aConstraint, circ.Signals, constraint.V1, used)
  106. aConstraint, used = insertVarNeg(aConstraint, circ.Signals, constraint.V2, used)
  107. bConstraint[0] = big.NewInt(int64(1))
  108. } else if constraint.Op == "*" {
  109. cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1))
  110. aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V1, used)
  111. bConstraint, used = insertVar(bConstraint, circ.Signals, constraint.V2, used)
  112. } else if constraint.Op == "/" {
  113. cConstraint, used = insertVar(cConstraint, circ.Signals, constraint.V1, used)
  114. cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1))
  115. bConstraint, used = insertVar(bConstraint, circ.Signals, constraint.V2, used)
  116. }
  117. a = append(a, aConstraint)
  118. b = append(b, bConstraint)
  119. c = append(c, cConstraint)
  120. }
  121. circ.R1CS.A = a
  122. circ.R1CS.B = b
  123. circ.R1CS.C = c
  124. return a, b, c
  125. }
  126. func grabVar(signals []string, w []*big.Int, vStr string) *big.Int {
  127. isVal, v := isValue(vStr)
  128. vBig := big.NewInt(int64(v))
  129. if isVal {
  130. return vBig
  131. } else {
  132. return w[indexInArray(signals, vStr)]
  133. }
  134. }
  135. type Inputs struct {
  136. Private []*big.Int
  137. Publics []*big.Int
  138. }
  139. // CalculateWitness calculates the Witness of a Circuit based on the given inputs
  140. // witness = [ one, output, publicInputs, privateInputs, ...]
  141. func (circ *Circuit) CalculateWitness(inputs []*big.Int) ([]*big.Int, error) {
  142. if len(inputs) != len(circ.Inputs) {
  143. return []*big.Int{}, errors.New("given inputs != circuit.Inputs")
  144. }
  145. w := r1csqap.ArrayOfBigZeros(len(circ.Signals))
  146. w[0] = big.NewInt(int64(1))
  147. for i, input := range inputs {
  148. w[i+2] = input
  149. }
  150. for _, constraint := range circ.Constraints {
  151. if constraint.Op == "in" {
  152. } else if constraint.Op == "+" {
  153. w[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Add(grabVar(circ.Signals, w, constraint.V1), grabVar(circ.Signals, w, constraint.V2))
  154. } else if constraint.Op == "-" {
  155. w[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Sub(grabVar(circ.Signals, w, constraint.V1), grabVar(circ.Signals, w, constraint.V2))
  156. } else if constraint.Op == "*" {
  157. w[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Mul(grabVar(circ.Signals, w, constraint.V1), grabVar(circ.Signals, w, constraint.V2))
  158. } else if constraint.Op == "/" {
  159. w[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Div(grabVar(circ.Signals, w, constraint.V1), grabVar(circ.Signals, w, constraint.V2))
  160. }
  161. }
  162. return w, nil
  163. }