You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

382 lines
9.5 KiB

  1. package circuitcompiler
  2. import (
  3. "fmt"
  4. "github.com/mottla/go-snark/r1csqap"
  5. "math/big"
  6. )
  7. type Program struct {
  8. functions map[string]*Circuit
  9. signals []string
  10. globalInputs []*Constraint
  11. R1CS struct {
  12. A [][]*big.Int
  13. B [][]*big.Int
  14. C [][]*big.Int
  15. }
  16. }
  17. func (p *Program) PrintContraintTrees() {
  18. for k, v := range p.functions {
  19. fmt.Println(k)
  20. PrintTree(v.root)
  21. }
  22. }
  23. func (p *Program) BuildConstraintTrees() {
  24. functionRootMap := make(map[string]*gate)
  25. for _, circuit := range p.functions {
  26. circuit.addConstraint(p.oneConstraint())
  27. fName := composeNewFunction(circuit.Name, circuit.Inputs)
  28. root := &gate{value: circuit.constraintMap[fName]}
  29. functionRootMap[fName] = root
  30. circuit.root = root
  31. }
  32. for _, circuit := range p.functions {
  33. buildTree(circuit.constraintMap, circuit.root)
  34. }
  35. return
  36. }
  37. func buildTree(con map[string]*Constraint, g *gate) {
  38. if _, ex := con[g.value.Out]; ex {
  39. if g.OperationType()&(IN|CONST) != 0 {
  40. return
  41. }
  42. } else {
  43. panic(fmt.Sprintf("undefined variable %s", g.value.Out))
  44. }
  45. if g.OperationType() == FUNC {
  46. g.funcInputs = []*gate{}
  47. for _, in := range g.value.Inputs {
  48. if constr, ex := con[in]; ex {
  49. newGate := &gate{value: constr}
  50. g.funcInputs = append(g.funcInputs, newGate)
  51. buildTree(con, newGate)
  52. } else {
  53. panic(fmt.Sprintf("undefined value %s", g.value.V1))
  54. }
  55. }
  56. return
  57. }
  58. if constr, ex := con[g.value.V1]; ex {
  59. g.addLeft(constr)
  60. buildTree(con, g.left)
  61. } else {
  62. panic(fmt.Sprintf("undefined value %s", g.value.V1))
  63. }
  64. if constr, ex := con[g.value.V2]; ex {
  65. g.addRight(constr)
  66. buildTree(con, g.right)
  67. } else {
  68. panic(fmt.Sprintf("undefined value %s", g.value.V2))
  69. }
  70. }
  71. func (p *Program) ReduceCombinedTree() (orderedmGates []gate) {
  72. mGatesUsed := make(map[string]bool)
  73. orderedmGates = []gate{}
  74. functionRootMap := make(map[string]*gate)
  75. for k, v := range p.functions {
  76. functionRootMap[k] = v.root
  77. }
  78. functionRenamer := func(c *Constraint) *gate {
  79. if c.Op != FUNC {
  80. panic("not a function")
  81. }
  82. if b, name, in := isFunction(c.Out); b {
  83. if k, v := p.functions[name]; v {
  84. //fmt.Println("unrenamed thing")
  85. //PrintTree(k.root)
  86. k.renameInputs(in)
  87. //fmt.Println("renamed thing")
  88. //PrintTree(k.root)
  89. return k.root
  90. }
  91. } else {
  92. panic("not a function dude")
  93. }
  94. return nil
  95. }
  96. traverseCombinedMultiplicationGates(p.getMainCircut().root, mGatesUsed, &orderedmGates, functionRootMap, functionRenamer, false, false)
  97. //for _, g := range mGates {
  98. // orderedmGates[len(orderedmGates)-1-g.index] = g
  99. //}
  100. return orderedmGates
  101. }
// traverseCombinedMultiplicationGates walks the gate tree rooted at root in
// post-order. It appends a clone of every MULTIPLY gate it reaches to
// *orderedmGates, after that gate's operand subtrees have been visited.
//
// Parameters:
//   - mGatesUsed: the Out names of gates already emitted. An operand whose name
//     (V1/V2) is in it is not traversed again.
//   - orderedmGates: output slice; clones are appended in emission order.
//   - functionRootMap: passed through to collectAtomsInSubtree.
//   - functionRenamer: maps a FUNC constraint to the root gate of the called
//     function (see ReduceCombinedTree).
//   - negate, inverse: the accumulated negate/invert flags for this subtree.
//     They are XOR-ed with the gate's own flags when descending to (and
//     collecting atoms from) the right operand.
//
// For a FUNC gate, the call arguments are traversed first, skipping those
// already emitted. Then the callee body is traversed with a fresh used-map
// that contains only the arguments already in mGatesUsed.
//
// For a MULTIPLY gate, the atoms of the left and right subtrees are collected
// into leftIns/rightIns. Then index is set to the current size of mGatesUsed,
// the gate is marked as used, and a clone is appended.
func traverseCombinedMultiplicationGates(root *gate, mGatesUsed map[string]bool, orderedmGates *[]gate, functionRootMap map[string]*gate, functionRenamer func(c *Constraint) *gate, negate bool, inverse bool) {
	// No nil check: root is assumed non-nil (children are only visited for
	// non-IN/CONST gates, which buildTree always expands).
	if root.OperationType() == FUNC {
		// Arguments that have already been built are passed to the callee's
		// traversal through newMap so it does not rebuild them.
		// NOTE(review): arguments traversed in this loop are NOT added to newMap;
		// presumably the callee sees them only as renamed IN leaves — confirm.
		newMap := make(map[string]bool)
		for _, in := range root.funcInputs {
			if _, ex := mGatesUsed[in.value.Out]; ex {
				newMap[in.value.Out] = true
			} else {
				traverseCombinedMultiplicationGates(in, mGatesUsed, orderedmGates, functionRootMap, functionRenamer, negate, inverse)
			}
		}
		// The FUNC gate itself is not recorded in mGatesUsed, and gates emitted
		// inside the callee are recorded only in newMap, not in the caller's map.
		traverseCombinedMultiplicationGates(functionRenamer(root.value), newMap, orderedmGates, functionRootMap, functionRenamer, negate, inverse)
	} else {
		// Descend into operands that were not yet emitted, unless root is a leaf (IN/CONST).
		if _, alreadyComputed := mGatesUsed[root.value.V1]; !alreadyComputed && root.OperationType()&(IN|CONST) == 0 {
			traverseCombinedMultiplicationGates(root.left, mGatesUsed, orderedmGates, functionRootMap, functionRenamer, negate, inverse)
		}
		// The right operand inherits root's own negate/invert flags on top of the accumulated ones.
		if _, alreadyComputed := mGatesUsed[root.value.V2]; !alreadyComputed && root.OperationType()&(IN|CONST) == 0 {
			traverseCombinedMultiplicationGates(root.right, mGatesUsed, orderedmGates, functionRootMap, functionRenamer, Xor(negate, root.value.negate), Xor(inverse, root.value.invert))
		}
	}
	if root.OperationType() == MULTIPLY {
		root.leftIns = make(map[string]int)
		collectAtomsInSubtree(root.left, root.leftIns, functionRootMap, negate, inverse)
		root.rightIns = make(map[string]int)
		collectAtomsInSubtree(root.right, root.rightIns, functionRootMap, Xor(negate, root.value.negate), Xor(inverse, root.value.invert))
		// NOTE(review): inside a callee this is the size of the callee's fresh
		// map, not a global position in orderedmGates — confirm this is intended.
		root.index = len(mGatesUsed)
		mGatesUsed[root.value.Out] = true
		// A clone is stored, so later renaming/reuse of root does not alter it.
		rootGate := cloneGate(root)
		*orderedmGates = append(*orderedmGates, *rootGate)
	}
	// TODO: optimize if the output is not a multiplication gate.
}
  138. //copies a gate neglecting its references to other gates
  139. func cloneGate(in *gate) (out *gate) {
  140. constr := &Constraint{Inputs: in.value.Inputs, Out: in.value.Out, Op: in.value.Op, invert: in.value.invert, negate: in.value.negate, V2: in.value.V2, V1: in.value.V1}
  141. nRightins := make(map[string]int)
  142. nLeftInst := make(map[string]int)
  143. for k, v := range in.rightIns {
  144. nRightins[k] = v
  145. }
  146. for k, v := range in.leftIns {
  147. nLeftInst[k] = v
  148. }
  149. return &gate{value: constr, leftIns: nLeftInst, rightIns: nRightins, index: in.index}
  150. }
  151. func (p *Program) getMainCircut() *Circuit {
  152. return p.functions["main"]
  153. }
  154. func (p *Program) addGlobalInput(c *Constraint) {
  155. p.globalInputs = append(p.globalInputs, c)
  156. }
  157. func NewProgramm() *Program {
  158. return &Program{functions: map[string]*Circuit{}, signals: []string{}, globalInputs: []*Constraint{{Op: CONST, Out: "one"}}}
  159. }
  160. func (p *Program) oneConstraint() *Constraint {
  161. if p.globalInputs[0].Out != "one" {
  162. panic("'one' should be first global input")
  163. }
  164. return p.globalInputs[0]
  165. }
  166. func (p *Program) addSignal(name string) {
  167. p.signals = append(p.signals, name)
  168. }
  169. func (p *Program) addFunction(constraint *Constraint) (c *Circuit) {
  170. name := constraint.Out
  171. fmt.Println("try to add function ", name)
  172. b, name2, _ := isFunction(name)
  173. if !b {
  174. panic(fmt.Sprintf("not a function: %v", constraint))
  175. }
  176. name = name2
  177. if _, ex := p.functions[name]; ex {
  178. panic("function already declared")
  179. }
  180. c = newCircuit(name)
  181. p.functions[name] = c
  182. //if constraint.Literal == "main" {
  183. for _, in := range constraint.Inputs {
  184. newConstr := &Constraint{
  185. Op: IN,
  186. Out: in,
  187. }
  188. if name == "main" {
  189. p.addGlobalInput(newConstr)
  190. }
  191. c.addConstraint(newConstr)
  192. }
  193. c.Inputs = constraint.Inputs
  194. return
  195. }
  196. // GenerateR1CS generates the R1CS polynomials from the Circuit
  197. func (p *Program) GenerateReducedR1CS(mGates []gate) (a, b, c [][]*big.Int) {
  198. // from flat code to R1CS
  199. offset := len(p.globalInputs)
  200. // one + in1 +in2+... + gate1 + gate2 .. + out
  201. size := offset + len(mGates)
  202. indexMap := make(map[string]int)
  203. //circ.Signals = []string{"one"}
  204. for i, v := range p.globalInputs {
  205. indexMap[v.Out] = i
  206. //circ.Signals = append(circ.Signals, v)
  207. }
  208. for i, v := range mGates {
  209. indexMap[v.value.Out] = i + offset
  210. //circ.Signals = append(circ.Signals, v.value.Out)
  211. }
  212. //circ.NVars = len(circ.Signals)
  213. //circ.NSignals = len(circ.Signals)
  214. for _, gate := range mGates {
  215. if gate.OperationType() == MULTIPLY {
  216. aConstraint := r1csqap.ArrayOfBigZeros(size)
  217. bConstraint := r1csqap.ArrayOfBigZeros(size)
  218. cConstraint := r1csqap.ArrayOfBigZeros(size)
  219. for leftInput, val := range gate.leftIns {
  220. insertVar3(aConstraint, val, leftInput, indexMap[leftInput])
  221. }
  222. for rightInput, val := range gate.rightIns {
  223. insertVar3(bConstraint, val, rightInput, indexMap[rightInput])
  224. }
  225. cConstraint[indexMap[gate.value.Out]] = big.NewInt(int64(1))
  226. if gate.value.invert {
  227. a = append(a, cConstraint)
  228. b = append(b, bConstraint)
  229. c = append(c, aConstraint)
  230. } else {
  231. a = append(a, aConstraint)
  232. b = append(b, bConstraint)
  233. c = append(c, cConstraint)
  234. }
  235. } else {
  236. panic("not a m gate")
  237. }
  238. }
  239. p.R1CS.A = a
  240. p.R1CS.B = b
  241. p.R1CS.C = c
  242. return a, b, c
  243. }
  244. func insertVar3(arr []*big.Int, val int, input string, index int) {
  245. isVal, value := isValue(input)
  246. var valueBigInt *big.Int
  247. if isVal {
  248. valueBigInt = big.NewInt(int64(value))
  249. arr[0] = new(big.Int).Add(arr[0], valueBigInt)
  250. } else {
  251. //if !indexMap[leftInput] {
  252. // panic(errors.New("using variable before it's set"))
  253. //}
  254. valueBigInt = big.NewInt(int64(val))
  255. arr[index] = new(big.Int).Add(arr[index], valueBigInt)
  256. }
  257. }
  258. func (p *Program) CalculateWitness(input []*big.Int) (witness []*big.Int) {
  259. if len(p.globalInputs)-1 != len(input) {
  260. panic("input do not match the required inputs")
  261. }
  262. witness = r1csqap.ArrayOfBigZeros(len(p.R1CS.A[0]))
  263. set := make([]bool, len(witness))
  264. witness[0] = big.NewInt(int64(1))
  265. set[0] = true
  266. for i := range input {
  267. witness[i+1] = input[i]
  268. set[i+1] = true
  269. }
  270. zero := big.NewInt(int64(0))
  271. for i := 0; i < len(p.R1CS.A); i++ {
  272. gatesLeftInputs := p.R1CS.A[i]
  273. gatesRightInputs := p.R1CS.B[i]
  274. gatesOutputs := p.R1CS.C[i]
  275. sumLeft := big.NewInt(int64(0))
  276. sumRight := big.NewInt(int64(0))
  277. sumOut := big.NewInt(int64(0))
  278. index := -1
  279. division := false
  280. for j, val := range gatesLeftInputs {
  281. if val.Cmp(zero) != 0 {
  282. if !set[j] {
  283. index = j
  284. division = true
  285. break
  286. }
  287. sumLeft.Add(sumLeft, new(big.Int).Mul(val, witness[j]))
  288. }
  289. }
  290. for j, val := range gatesRightInputs {
  291. if val.Cmp(zero) != 0 {
  292. sumRight.Add(sumRight, new(big.Int).Mul(val, witness[j]))
  293. }
  294. }
  295. for j, val := range gatesOutputs {
  296. if val.Cmp(zero) != 0 {
  297. if !set[j] {
  298. if index != -1 {
  299. panic("invalid R1CS form")
  300. }
  301. index = j
  302. break
  303. }
  304. sumOut.Add(sumOut, new(big.Int).Mul(val, witness[j]))
  305. }
  306. }
  307. if !division {
  308. set[index] = true
  309. witness[index] = new(big.Int).Mul(sumLeft, sumRight)
  310. } else {
  311. b := sumRight.Int64()
  312. c := sumOut.Int64()
  313. set[index] = true
  314. witness[index] = big.NewInt(c / b)
  315. }
  316. }
  317. return
  318. }