You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

186 lines
6.0 KiB

5 years ago
5 years ago
  1. package circuitcompiler
  2. import (
  3. "errors"
  4. "math/big"
  5. "strconv"
  6. "github.com/arnaucube/go-snark-study/r1csqap"
  7. )
  8. // Circuit is the data structure of the compiled circuit
  9. type Circuit struct {
  10. NVars int
  11. NPublic int
  12. NSignals int
  13. PrivateInputs []string
  14. PublicInputs []string
  15. Signals []string
  16. Witness []*big.Int
  17. Constraints []Constraint
  18. R1CS struct {
  19. A [][]*big.Int
  20. B [][]*big.Int
  21. C [][]*big.Int
  22. }
  23. }
  // Constraint is the data structure of a flat code operation, read as
// "Out = V1 Op V2".
type Constraint struct {
	// Op names the operation; the code in this file handles "in", "+", "-",
	// "*" and "/".
	Op string
	// V1 and V2 are the operands: each is either a decimal integer literal or
	// the name of a signal.
	V1, V2 string
	// Out is the signal that receives the result.
	Out string
	// Literal is not read in the visible part of this file; presumably it keeps
	// the raw source text of the operation — TODO confirm.
	Literal string
	// PrivateInputs and PublicInputs are used in the func declaration case.
	PrivateInputs, PublicInputs []string
}
  // indexInArray returns the position of the first element of arr equal to e,
// or -1 when e does not occur (including when arr is empty or nil).
func indexInArray(arr []string, e string) int {
	found := -1
	for pos := 0; pos < len(arr) && found < 0; pos++ {
		if arr[pos] == e {
			found = pos
		}
	}
	return found
}
  // isValue reports whether a is a base-10 integer literal as accepted by
// strconv.Atoi (optional sign, no spaces) and, if so, returns its value.
// For any other string it returns (false, 0).
func isValue(a string) (bool, int) {
	if n, err := strconv.Atoi(a); err == nil {
		return true, n
	}
	return false, 0
}
  50. func insertVar(arr []*big.Int, signals []string, v string, used map[string]bool) ([]*big.Int, map[string]bool) {
  51. isVal, value := isValue(v)
  52. valueBigInt := big.NewInt(int64(value))
  53. if isVal {
  54. arr[0] = new(big.Int).Add(arr[0], valueBigInt)
  55. } else {
  56. if !used[v] {
  57. panic(errors.New("using variable before it's set"))
  58. }
  59. arr[indexInArray(signals, v)] = new(big.Int).Add(arr[indexInArray(signals, v)], big.NewInt(int64(1)))
  60. }
  61. return arr, used
  62. }
  63. func insertVarNeg(arr []*big.Int, signals []string, v string, used map[string]bool) ([]*big.Int, map[string]bool) {
  64. isVal, value := isValue(v)
  65. valueBigInt := big.NewInt(int64(value))
  66. if isVal {
  67. arr[0] = new(big.Int).Add(arr[0], valueBigInt)
  68. } else {
  69. if !used[v] {
  70. panic(errors.New("using variable before it's set"))
  71. }
  72. arr[indexInArray(signals, v)] = new(big.Int).Add(arr[indexInArray(signals, v)], big.NewInt(int64(-1)))
  73. }
  74. return arr, used
  75. }
  76. // GenerateR1CS generates the R1CS polynomials from the Circuit
  77. func (circ *Circuit) GenerateR1CS() ([][]*big.Int, [][]*big.Int, [][]*big.Int) {
  78. // from flat code to R1CS
  79. var a [][]*big.Int
  80. var b [][]*big.Int
  81. var c [][]*big.Int
  82. used := make(map[string]bool)
  83. for _, constraint := range circ.Constraints {
  84. aConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals))
  85. bConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals))
  86. cConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals))
  87. // if existInArray(constraint.Out) {
  88. // if used[constraint.Out] {
  89. // panic(errors.New("out variable already used: " + constraint.Out))
  90. // }
  91. used[constraint.Out] = true
  92. if constraint.Op == "in" {
  93. for i := 0; i <= len(circ.PublicInputs); i++ {
  94. aConstraint[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Add(aConstraint[indexInArray(circ.Signals, constraint.Out)], big.NewInt(int64(1)))
  95. aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.Out, used)
  96. bConstraint[0] = big.NewInt(int64(1))
  97. }
  98. continue
  99. } else if constraint.Op == "+" {
  100. cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1))
  101. aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V1, used)
  102. aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V2, used)
  103. bConstraint[0] = big.NewInt(int64(1))
  104. } else if constraint.Op == "-" {
  105. cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1))
  106. aConstraint, used = insertVarNeg(aConstraint, circ.Signals, constraint.V1, used)
  107. aConstraint, used = insertVarNeg(aConstraint, circ.Signals, constraint.V2, used)
  108. bConstraint[0] = big.NewInt(int64(1))
  109. } else if constraint.Op == "*" {
  110. cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1))
  111. aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V1, used)
  112. bConstraint, used = insertVar(bConstraint, circ.Signals, constraint.V2, used)
  113. } else if constraint.Op == "/" {
  114. cConstraint, used = insertVar(cConstraint, circ.Signals, constraint.V1, used)
  115. cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1))
  116. bConstraint, used = insertVar(bConstraint, circ.Signals, constraint.V2, used)
  117. }
  118. a = append(a, aConstraint)
  119. b = append(b, bConstraint)
  120. c = append(c, cConstraint)
  121. }
  122. circ.R1CS.A = a
  123. circ.R1CS.B = b
  124. circ.R1CS.C = c
  125. return a, b, c
  126. }
  127. func grabVar(signals []string, w []*big.Int, vStr string) *big.Int {
  128. isVal, v := isValue(vStr)
  129. vBig := big.NewInt(int64(v))
  130. if isVal {
  131. return vBig
  132. } else {
  133. return w[indexInArray(signals, vStr)]
  134. }
  135. }
  // Inputs groups a set of private and public input values. It is not
// referenced in the visible part of this file; presumably its fields mirror
// the privateInputs/publicInputs arguments of CalculateWitness — TODO confirm.
type Inputs struct {
	Private, Public []*big.Int
}
  140. // CalculateWitness calculates the Witness of a Circuit based on the given inputs
  141. // witness = [ one, output, publicInputs, privateInputs, ...]
  142. func (circ *Circuit) CalculateWitness(privateInputs []*big.Int, publicInputs []*big.Int) ([]*big.Int, error) {
  143. if len(privateInputs) != len(circ.PrivateInputs) {
  144. return []*big.Int{}, errors.New("given privateInputs != circuit.PublicInputs")
  145. }
  146. if len(publicInputs) != len(circ.PublicInputs) {
  147. return []*big.Int{}, errors.New("given publicInputs != circuit.PublicInputs")
  148. }
  149. w := r1csqap.ArrayOfBigZeros(len(circ.Signals))
  150. w[0] = big.NewInt(int64(1))
  151. for i, input := range publicInputs {
  152. w[i+1] = input
  153. }
  154. for i, input := range privateInputs {
  155. w[i+len(publicInputs)+1] = input
  156. }
  157. for _, constraint := range circ.Constraints {
  158. if constraint.Op == "in" {
  159. } else if constraint.Op == "+" {
  160. w[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Add(grabVar(circ.Signals, w, constraint.V1), grabVar(circ.Signals, w, constraint.V2))
  161. } else if constraint.Op == "-" {
  162. w[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Sub(grabVar(circ.Signals, w, constraint.V1), grabVar(circ.Signals, w, constraint.V2))
  163. } else if constraint.Op == "*" {
  164. w[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Mul(grabVar(circ.Signals, w, constraint.V1), grabVar(circ.Signals, w, constraint.V2))
  165. } else if constraint.Op == "/" {
  166. w[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Div(grabVar(circ.Signals, w, constraint.V1), grabVar(circ.Signals, w, constraint.V2))
  167. }
  168. }
  169. return w, nil
  170. }