You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

191 lines
6.1 KiB

5 years ago
5 years ago
  1. package circuitcompiler
  2. import (
  3. "errors"
  4. "fmt"
  5. "math/big"
  6. "strconv"
  7. "github.com/arnaucube/go-snark/r1csqap"
  8. )
  9. // Circuit is the data structure of the compiled circuit
  10. type Circuit struct {
  11. NVars int
  12. NPublic int
  13. NSignals int
  14. PrivateInputs []string
  15. PublicInputs []string
  16. Signals []string
  17. Witness []*big.Int
  18. Constraints []Constraint
  19. R1CS struct {
  20. A [][]*big.Int
  21. B [][]*big.Int
  22. C [][]*big.Int
  23. }
  24. }
  25. // Constraint is the data structure of a flat code operation
  26. type Constraint struct {
  27. // v1 op v2 = out
  28. Op string
  29. V1 string
  30. V2 string
  31. Out string
  32. Literal string
  33. PrivateInputs []string // in func declaration case
  34. PublicInputs []string // in func declaration case
  35. }
  36. func indexInArray(arr []string, e string) int {
  37. for i, a := range arr {
  38. if a == e {
  39. return i
  40. }
  41. }
  42. return -1
  43. }
  44. func isValue(a string) (bool, int) {
  45. v, err := strconv.Atoi(a)
  46. if err != nil {
  47. return false, 0
  48. }
  49. return true, v
  50. }
  51. func insertVar(arr []*big.Int, signals []string, v string, used map[string]bool) ([]*big.Int, map[string]bool) {
  52. isVal, value := isValue(v)
  53. valueBigInt := big.NewInt(int64(value))
  54. if isVal {
  55. arr[0] = new(big.Int).Add(arr[0], valueBigInt)
  56. } else {
  57. if !used[v] {
  58. panic(errors.New("using variable before it's set"))
  59. }
  60. arr[indexInArray(signals, v)] = new(big.Int).Add(arr[indexInArray(signals, v)], big.NewInt(int64(1)))
  61. }
  62. return arr, used
  63. }
  64. func insertVarNeg(arr []*big.Int, signals []string, v string, used map[string]bool) ([]*big.Int, map[string]bool) {
  65. isVal, value := isValue(v)
  66. valueBigInt := big.NewInt(int64(value))
  67. if isVal {
  68. arr[0] = new(big.Int).Add(arr[0], valueBigInt)
  69. } else {
  70. if !used[v] {
  71. panic(errors.New("using variable before it's set"))
  72. }
  73. arr[indexInArray(signals, v)] = new(big.Int).Add(arr[indexInArray(signals, v)], big.NewInt(int64(-1)))
  74. }
  75. return arr, used
  76. }
  77. // GenerateR1CS generates the R1CS polynomials from the Circuit
  78. func (circ *Circuit) GenerateR1CS() ([][]*big.Int, [][]*big.Int, [][]*big.Int) {
  79. // from flat code to R1CS
  80. var a [][]*big.Int
  81. var b [][]*big.Int
  82. var c [][]*big.Int
  83. used := make(map[string]bool)
  84. for _, constraint := range circ.Constraints {
  85. aConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals))
  86. bConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals))
  87. cConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals))
  88. // if existInArray(constraint.Out) {
  89. if used[constraint.Out] {
  90. // panic(errors.New("out variable already used: " + constraint.Out))
  91. fmt.Println("variable already used")
  92. }
  93. used[constraint.Out] = true
  94. if constraint.Op == "in" {
  95. for i := 0; i <= len(circ.PublicInputs); i++ {
  96. aConstraint[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Add(aConstraint[indexInArray(circ.Signals, constraint.Out)], big.NewInt(int64(1)))
  97. aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.Out, used)
  98. bConstraint[0] = big.NewInt(int64(1))
  99. }
  100. continue
  101. } else if constraint.Op == "+" {
  102. cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1))
  103. aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V1, used)
  104. aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V2, used)
  105. bConstraint[0] = big.NewInt(int64(1))
  106. } else if constraint.Op == "-" {
  107. cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1))
  108. aConstraint, used = insertVarNeg(aConstraint, circ.Signals, constraint.V1, used)
  109. aConstraint, used = insertVarNeg(aConstraint, circ.Signals, constraint.V2, used)
  110. bConstraint[0] = big.NewInt(int64(1))
  111. } else if constraint.Op == "*" {
  112. cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1))
  113. aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V1, used)
  114. bConstraint, used = insertVar(bConstraint, circ.Signals, constraint.V2, used)
  115. } else if constraint.Op == "/" {
  116. cConstraint, used = insertVar(cConstraint, circ.Signals, constraint.V1, used)
  117. cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1))
  118. bConstraint, used = insertVar(bConstraint, circ.Signals, constraint.V2, used)
  119. }
  120. a = append(a, aConstraint)
  121. b = append(b, bConstraint)
  122. c = append(c, cConstraint)
  123. }
  124. circ.R1CS.A = a
  125. circ.R1CS.B = b
  126. circ.R1CS.C = c
  127. return a, b, c
  128. }
  129. func grabVar(signals []string, w []*big.Int, vStr string) *big.Int {
  130. isVal, v := isValue(vStr)
  131. vBig := big.NewInt(int64(v))
  132. if isVal {
  133. return vBig
  134. } else {
  135. return w[indexInArray(signals, vStr)]
  136. }
  137. }
  138. type Inputs struct {
  139. Private []*big.Int
  140. Publics []*big.Int
  141. }
  142. // CalculateWitness calculates the Witness of a Circuit based on the given inputs
  143. // witness = [ one, output, publicInputs, privateInputs, ...]
  144. func (circ *Circuit) CalculateWitness(privateInputs []*big.Int, publicInputs []*big.Int) ([]*big.Int, error) {
  145. if len(privateInputs) != len(circ.PrivateInputs) {
  146. return []*big.Int{}, errors.New("given privateInputs != circuit.PublicInputs")
  147. }
  148. if len(publicInputs) != len(circ.PublicInputs) {
  149. return []*big.Int{}, errors.New("given publicInputs != circuit.PublicInputs")
  150. }
  151. w := r1csqap.ArrayOfBigZeros(len(circ.Signals))
  152. w[0] = big.NewInt(int64(1))
  153. for i, input := range publicInputs {
  154. fmt.Println(i + 1)
  155. fmt.Println(input)
  156. w[i+1] = input
  157. }
  158. for i, input := range privateInputs {
  159. fmt.Println(i + len(publicInputs) + 1)
  160. w[i+len(publicInputs)+1] = input
  161. }
  162. for _, constraint := range circ.Constraints {
  163. if constraint.Op == "in" {
  164. } else if constraint.Op == "+" {
  165. w[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Add(grabVar(circ.Signals, w, constraint.V1), grabVar(circ.Signals, w, constraint.V2))
  166. } else if constraint.Op == "-" {
  167. w[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Sub(grabVar(circ.Signals, w, constraint.V1), grabVar(circ.Signals, w, constraint.V2))
  168. } else if constraint.Op == "*" {
  169. w[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Mul(grabVar(circ.Signals, w, constraint.V1), grabVar(circ.Signals, w, constraint.V2))
  170. } else if constraint.Op == "/" {
  171. w[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Div(grabVar(circ.Signals, w, constraint.V1), grabVar(circ.Signals, w, constraint.V2))
  172. }
  173. }
  174. return w, nil
  175. }