You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

471 lines
12 KiB

  1. package circuitcompiler
  2. import (
  3. "fmt"
  4. "github.com/mottla/go-snark/r1csqap"
  5. "math/big"
  6. )
  7. type Program struct {
  8. functions map[string]*Circuit
  9. signals []string
  10. globalInputs []*Constraint
  11. R1CS struct {
  12. A [][]*big.Int
  13. B [][]*big.Int
  14. C [][]*big.Int
  15. }
  16. }
  17. func (p *Program) PrintContraintTrees() {
  18. for k, v := range p.functions {
  19. fmt.Println(k)
  20. PrintTree(v.root)
  21. }
  22. }
  23. func (p *Program) BuildConstraintTrees() {
  24. functionRootMap := make(map[string]*gate)
  25. for _, circuit := range p.functions {
  26. //circuit.addConstraint(p.oneConstraint())
  27. fName := composeNewFunction(circuit.Name, circuit.Inputs)
  28. root := &gate{value: circuit.constraintMap[fName]}
  29. functionRootMap[fName] = root
  30. circuit.root = root
  31. }
  32. for _, circuit := range p.functions {
  33. buildTree(circuit.constraintMap, circuit.root)
  34. }
  35. return
  36. }
  37. func buildTree(con map[string]*Constraint, g *gate) {
  38. if _, ex := con[g.value.Out]; ex {
  39. if g.OperationType()&(IN|CONST) != 0 {
  40. return
  41. }
  42. } else {
  43. panic(fmt.Sprintf("undefined variable %s", g.value.Out))
  44. }
  45. if g.OperationType() == FUNC {
  46. g.funcInputs = []*gate{}
  47. for _, in := range g.value.Inputs {
  48. if constr, ex := con[in]; ex {
  49. newGate := &gate{value: constr}
  50. g.funcInputs = append(g.funcInputs, newGate)
  51. buildTree(con, newGate)
  52. } else {
  53. panic(fmt.Sprintf("undefined value %s", g.value.V1))
  54. }
  55. }
  56. return
  57. }
  58. if constr, ex := con[g.value.V1]; ex {
  59. g.addLeft(constr)
  60. buildTree(con, g.left)
  61. } else {
  62. panic(fmt.Sprintf("undefined value %s", g.value.V1))
  63. }
  64. if constr, ex := con[g.value.V2]; ex {
  65. g.addRight(constr)
  66. buildTree(con, g.right)
  67. } else {
  68. panic(fmt.Sprintf("undefined value %s", g.value.V2))
  69. }
  70. }
  71. func (p *Program) ReduceCombinedTree() (orderedmGates []gate) {
  72. mGatesUsed := make(map[string]bool)
  73. orderedmGates = []gate{}
  74. functionRootMap := make(map[string]*gate)
  75. for k, v := range p.functions {
  76. functionRootMap[k] = v.root
  77. }
  78. functionRenamer := func(c *Constraint) *gate {
  79. if c.Op != FUNC {
  80. panic("not a function")
  81. }
  82. if b, name, in := isFunction(c.Out); b {
  83. if k, v := p.functions[name]; v {
  84. //fmt.Println("unrenamed thing")
  85. //PrintTree(k.root)
  86. k.renameInputs(in)
  87. //fmt.Println("renamed thing")
  88. //PrintTree(k.root)
  89. return k.root
  90. }
  91. } else {
  92. panic("not a function dude")
  93. }
  94. return nil
  95. }
  96. traverseCombinedMultiplicationGates(p.getMainCircut().root, mGatesUsed, &orderedmGates, functionRootMap, functionRenamer, false, false)
  97. //for _, g := range mGates {
  98. // orderedmGates[len(orderedmGates)-1-g.index] = g
  99. //}
  100. return orderedmGates
  101. }
// traverseCombinedMultiplicationGates walks the gate tree rooted at root in
// post-order. It appends to orderedmGates each multiplication gate that
// becomes its own R1CS row, in the order the gates are completed.
//
// Parameters:
//   - mGatesUsed: Out names of multiplication gates already emitted. Operands
//     found here are not traversed again, and every gate emitted here adds
//     its Out to the set.
//   - orderedmGates: accumulator that receives a cloneGate copy of each
//     emitted gate.
//   - functionRootMap: function name -> root gate; only forwarded to
//     collectAtomsInSubtree.
//   - functionRenamer: maps a FUNC constraint to the root gate of the called
//     function (as supplied by ReduceCombinedTree).
//   - negate, inverse: flags accumulated along the path. When descending
//     into a right operand they are XORed with that gate's negate/invert.
func traverseCombinedMultiplicationGates(root *gate, mGatesUsed map[string]bool, orderedmGates *[]gate, functionRootMap map[string]*gate, functionRenamer func(c *Constraint) *gate, negate bool, inverse bool) {
	//if root == nil {
	//	return
	//}
	//fmt.Printf("\n%p",mGatesUsed)
	if root.OperationType() == FUNC {
		// Function call: first process argument subtrees not yet emitted,
		// then continue into the called function's body. Flags pass through
		// unchanged.
		//if a input has already been built, we let this subroutine know
		//newMap := make(map[string]bool)
		for _, in := range root.funcInputs {
			if _, ex := mGatesUsed[in.value.Out]; ex {
				//newMap[in.value.Out] = true
			} else {
				traverseCombinedMultiplicationGates(in, mGatesUsed, orderedmGates, functionRootMap, functionRenamer, negate, inverse)
			}
		}
		//mGatesUsed[root.value.Out] = true
		traverseCombinedMultiplicationGates(functionRenamer(root.value), mGatesUsed, orderedmGates, functionRootMap, functionRenamer, negate, inverse)
	} else {
		// Ordinary gate: descend into operands that have not been emitted
		// yet. IN and CONST gates are leaves and are never descended.
		if _, alreadyComputed := mGatesUsed[root.value.V1]; !alreadyComputed && root.OperationType()&(IN|CONST) == 0 {
			traverseCombinedMultiplicationGates(root.left, mGatesUsed, orderedmGates, functionRootMap, functionRenamer, negate, inverse)
		}
		// The right operand inherits this gate's own negate/invert flags.
		if _, alreadyComputed := mGatesUsed[root.value.V2]; !alreadyComputed && root.OperationType()&(IN|CONST) == 0 {
			traverseCombinedMultiplicationGates(root.right, mGatesUsed, orderedmGates, functionRootMap, functionRenamer, Xor(negate, root.value.negate), Xor(inverse, root.value.invert))
		}
	}
	if root.OperationType() == MULTIPLY {
		_, n, _ := isFunction(root.value.Out)
		// A multiplication with a CONST operand is not emitted when the name
		// isFunction extracts from Out is not "main". Such a gate never
		// enters mGatesUsed, so collectAtomsInSubtree later walks through it
		// and scales the coefficient by the constant instead.
		// NOTE(review): presumably n is the enclosing function's name, and
		// why main is exempt is not visible here -- confirm.
		if (root.left.OperationType()|root.right.OperationType())&CONST != 0 && n != "main" {
			return
		}
		// Collect the left and right linear combinations of this gate,
		// stopping at already-emitted gates and at inputs/constants.
		root.leftIns = make(map[string]int)
		collectAtomsInSubtree(root.left, mGatesUsed, 1, root.leftIns, functionRootMap, negate, inverse)
		root.rightIns = make(map[string]int)
		//if root.left.value.Out== root.right.value.Out{
		//	//note this is not a full copy, but shouldnt be a problem
		//	root.rightIns= root.leftIns
		//}else{
		//	collectAtomsInSubtree(root.right, mGatesUsed, 1, root.rightIns, functionRootMap, Xor(negate, root.value.negate), Xor(inverse, root.value.invert))
		//}
		collectAtomsInSubtree(root.right, mGatesUsed, 1, root.rightIns, functionRootMap, Xor(negate, root.value.negate), Xor(inverse, root.value.invert))
		// index = number of entries in mGatesUsed before this gate is added.
		root.index = len(mGatesUsed)
		mGatesUsed[root.value.Out] = true
		// Store a copy detached from the tree (cloneGate drops child links).
		rootGate := cloneGate(root)
		*orderedmGates = append(*orderedmGates, *rootGate)
	}
	//TODO optimize if output is not a multipication gate
}
// collectAtomsInSubtree accumulates the linear combination computed by the
// subtree rooted at g into in. Each atom (signal name) is mapped to an
// integer coefficient via addToMap.
//
// Descent stops at two kinds of gates, both recorded as atoms:
//   - gates whose Out is already in mGatesUsed (emitted multiplication gates);
//   - IN and CONST gates.
//
// A MULTIPLY gate not yet emitted must have exactly one constant operand,
// as judged by isValue on V1/V2. The coefficient is multiplied by that
// constant and only the other operand is descended. It panics if both or
// neither operand is a constant.
//
// A FUNC gate descends into functionRootMap[name], the root gate registered
// for the called function.
//
// The remaining parameters:
//   - multiplicative: the coefficient applied to atoms found below g.
//   - negate: selects subtraction in addToMap. It is XORed with a gate's own
//     negate flag when descending into that gate's right operand.
//   - invert: only passed along to recursive calls; it is not otherwise read
//     here.
func collectAtomsInSubtree(g *gate, mGatesUsed map[string]bool, multiplicative int, in map[string]int, functionRootMap map[string]*gate, negate bool, invert bool) {
	if g == nil {
		return
	}
	// Already emitted as its own multiplication gate: treat as an atom.
	if _, ex := mGatesUsed[g.value.Out]; ex {
		addToMap(g.value.Out, multiplicative, in, negate)
		return
	}
	// Inputs and constants are leaves: treat as atoms.
	if g.OperationType()&(IN|CONST) != 0 {
		addToMap(g.value.Out, multiplicative, in, negate)
		return
	}
	// Multiplication by a constant: fold the constant into the coefficient.
	if g.OperationType()&(MULTIPLY) != 0 {
		b1, v1 := isValue(g.value.V1)
		b2, v2 := isValue(g.value.V2)
		if b1 && !b2 {
			multiplicative *= v1
			collectAtomsInSubtree(g.right, mGatesUsed, multiplicative, in, functionRootMap, Xor(negate, g.value.negate), invert)
			return
		} else if !b1 && b2 {
			multiplicative *= v2
			collectAtomsInSubtree(g.left, mGatesUsed, multiplicative, in, functionRootMap, negate, invert)
			return
		} else if b1 && b2 {
			panic("multiply constants not supported yet")
		} else {
			panic("werird")
		}
	}
	if g.OperationType() == FUNC {
		if b, name, _ := isFunction(g.value.Out); b {
			collectAtomsInSubtree(functionRootMap[name], mGatesUsed, multiplicative, in, functionRootMap, negate, invert)
		} else {
			panic("function expected")
		}
	}
	// NOTE(review): FUNC gates also fall through to here. buildTree does not
	// set left/right for FUNC gates, so these calls presumably hit the nil
	// check above and do nothing -- confirm.
	collectAtomsInSubtree(g.left, mGatesUsed, multiplicative, in, functionRootMap, negate, invert)
	collectAtomsInSubtree(g.right, mGatesUsed, multiplicative, in, functionRootMap, Xor(negate, g.value.negate), invert)
}
  188. func addOneToMap(value string, in map[string]int, negate bool) {
  189. addToMap(value, 1, in, negate)
  190. }
  191. func addToMap(value string, val int, in map[string]int, negate bool) {
  192. if negate {
  193. in[value] = (in[value] - 1) * val
  194. } else {
  195. in[value] = (in[value] + 1) * val
  196. }
  197. }
  198. //copies a gate neglecting its references to other gates
  199. func cloneGate(in *gate) (out *gate) {
  200. constr := &Constraint{Inputs: in.value.Inputs, Out: in.value.Out, Op: in.value.Op, invert: in.value.invert, negate: in.value.negate, V2: in.value.V2, V1: in.value.V1}
  201. nRightins := make(map[string]int)
  202. nLeftInst := make(map[string]int)
  203. for k, v := range in.rightIns {
  204. nRightins[k] = v
  205. }
  206. for k, v := range in.leftIns {
  207. nLeftInst[k] = v
  208. }
  209. return &gate{value: constr, leftIns: nLeftInst, rightIns: nRightins, index: in.index}
  210. }
  211. func (p *Program) getMainCircut() *Circuit {
  212. return p.functions["main"]
  213. }
  214. func (p *Program) addGlobalInput(c *Constraint) {
  215. p.globalInputs = append(p.globalInputs, c)
  216. }
  217. func NewProgramm() *Program {
  218. //return &Program{functions: map[string]*Circuit{}, signals: []string{}, globalInputs: []*Constraint{{Op: PLUS, V1:"1",V2:"0", Out: "one"}}}
  219. return &Program{functions: map[string]*Circuit{}, signals: []string{}, globalInputs: []*Constraint{{Op: IN, Out: "one"}}}
  220. }
  221. //func (p *Program) oneConstraint() *Constraint {
  222. // if p.globalInputs[0].Out != "one" {
  223. // panic("'one' should be first global input")
  224. // }
  225. // return p.globalInputs[0]
  226. //}
  227. func (p *Program) addSignal(name string) {
  228. p.signals = append(p.signals, name)
  229. }
  230. func (p *Program) addFunction(constraint *Constraint) (c *Circuit) {
  231. name := constraint.Out
  232. fmt.Println("try to add function ", name)
  233. b, name2, _ := isFunction(name)
  234. if !b {
  235. panic(fmt.Sprintf("not a function: %v", constraint))
  236. }
  237. name = name2
  238. if _, ex := p.functions[name]; ex {
  239. panic("function already declared")
  240. }
  241. c = newCircuit(name)
  242. p.functions[name] = c
  243. //if constraint.Literal == "main" {
  244. for _, in := range constraint.Inputs {
  245. newConstr := &Constraint{
  246. Op: IN,
  247. Out: in,
  248. }
  249. if name == "main" {
  250. p.addGlobalInput(newConstr)
  251. }
  252. c.addConstraint(newConstr)
  253. }
  254. c.Inputs = constraint.Inputs
  255. return
  256. }
  257. // GenerateR1CS generates the R1CS polynomials from the Circuit
  258. func (p *Program) GenerateReducedR1CS(mGates []gate) (a, b, c [][]*big.Int) {
  259. // from flat code to R1CS
  260. offset := len(p.globalInputs)
  261. // one + in1 +in2+... + gate1 + gate2 .. + out
  262. size := offset + len(mGates)
  263. indexMap := make(map[string]int)
  264. //circ.Signals = []string{"one"}
  265. for i, v := range p.globalInputs {
  266. indexMap[v.Out] = i
  267. //circ.Signals = append(circ.Signals, v)
  268. }
  269. for i, v := range mGates {
  270. indexMap[v.value.Out] = i + offset
  271. //circ.Signals = append(circ.Signals, v.value.Out)
  272. }
  273. //circ.NVars = len(circ.Signals)
  274. //circ.NSignals = len(circ.Signals)
  275. for _, gate := range mGates {
  276. if gate.OperationType() == MULTIPLY {
  277. aConstraint := r1csqap.ArrayOfBigZeros(size)
  278. bConstraint := r1csqap.ArrayOfBigZeros(size)
  279. cConstraint := r1csqap.ArrayOfBigZeros(size)
  280. //if len(gate.leftIns)>=len(gate.rightIns){
  281. // for leftInput, _ := range gate.leftIns {
  282. // if v, ex := gate.rightIns[leftInput]; ex {
  283. // gate.leftIns[leftInput] *= v
  284. // gate.rightIns[leftInput] = 1
  285. //
  286. // }
  287. // }
  288. //}else{
  289. // for rightInput, _ := range gate.rightIns {
  290. // if v, ex := gate.leftIns[rightInput]; ex {
  291. // gate.rightIns[rightInput] *= v
  292. // gate.leftIns[rightInput] = 1
  293. // }
  294. // }
  295. //}
  296. for leftInput, val := range gate.leftIns {
  297. insertVar3(aConstraint, val, leftInput, indexMap[leftInput])
  298. }
  299. for rightInput, val := range gate.rightIns {
  300. insertVar3(bConstraint, val, rightInput, indexMap[rightInput])
  301. }
  302. cConstraint[indexMap[gate.value.Out]] = big.NewInt(int64(1))
  303. if gate.value.invert {
  304. a = append(a, cConstraint)
  305. b = append(b, bConstraint)
  306. c = append(c, aConstraint)
  307. } else {
  308. a = append(a, aConstraint)
  309. b = append(b, bConstraint)
  310. c = append(c, cConstraint)
  311. }
  312. } else {
  313. panic("not a m gate")
  314. }
  315. }
  316. p.R1CS.A = a
  317. p.R1CS.B = b
  318. p.R1CS.C = c
  319. return a, b, c
  320. }
  321. func insertVar3(arr []*big.Int, val int, input string, index int) {
  322. isVal, value := isValue(input)
  323. var valueBigInt *big.Int
  324. if isVal {
  325. valueBigInt = big.NewInt(int64(value))
  326. arr[0] = new(big.Int).Add(arr[0], valueBigInt)
  327. } else {
  328. //if !indexMap[leftInput] {
  329. // panic(errors.New("using variable before it's set"))
  330. //}
  331. valueBigInt = big.NewInt(int64(val))
  332. arr[index] = new(big.Int).Add(arr[index], valueBigInt)
  333. }
  334. }
  335. func (p *Program) CalculateWitness(input []*big.Int) (witness []*big.Int) {
  336. if len(p.globalInputs)-1 != len(input) {
  337. panic("input do not match the required inputs")
  338. }
  339. witness = r1csqap.ArrayOfBigZeros(len(p.R1CS.A[0]))
  340. set := make([]bool, len(witness))
  341. witness[0] = big.NewInt(int64(1))
  342. set[0] = true
  343. for i := range input {
  344. witness[i+1] = input[i]
  345. set[i+1] = true
  346. }
  347. zero := big.NewInt(int64(0))
  348. for i := 0; i < len(p.R1CS.A); i++ {
  349. gatesLeftInputs := p.R1CS.A[i]
  350. gatesRightInputs := p.R1CS.B[i]
  351. gatesOutputs := p.R1CS.C[i]
  352. sumLeft := big.NewInt(int64(0))
  353. sumRight := big.NewInt(int64(0))
  354. sumOut := big.NewInt(int64(0))
  355. index := -1
  356. division := false
  357. for j, val := range gatesLeftInputs {
  358. if val.Cmp(zero) != 0 {
  359. if !set[j] {
  360. index = j
  361. division = true
  362. break
  363. }
  364. sumLeft.Add(sumLeft, new(big.Int).Mul(val, witness[j]))
  365. }
  366. }
  367. for j, val := range gatesRightInputs {
  368. if val.Cmp(zero) != 0 {
  369. sumRight.Add(sumRight, new(big.Int).Mul(val, witness[j]))
  370. }
  371. }
  372. for j, val := range gatesOutputs {
  373. if val.Cmp(zero) != 0 {
  374. if !set[j] {
  375. if index != -1 {
  376. panic("invalid R1CS form")
  377. }
  378. index = j
  379. break
  380. }
  381. sumOut.Add(sumOut, new(big.Int).Mul(val, witness[j]))
  382. }
  383. }
  384. if !division {
  385. set[index] = true
  386. witness[index] = new(big.Int).Mul(sumLeft, sumRight)
  387. } else {
  388. b := sumRight.Int64()
  389. c := sumOut.Int64()
  390. set[index] = true
  391. witness[index] = big.NewInt(c / b)
  392. }
  393. }
  394. return
  395. }