You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

608 lines
17 KiB

  1. package circuitcompiler
import (
	"fmt"
	"math/big"
	"sort"
	"sync"

	"github.com/mottla/go-snark/bn128"
	"github.com/mottla/go-snark/fields"
	"github.com/mottla/go-snark/r1csqap"
)
  10. type utils struct {
  11. Bn bn128.Bn128
  12. FqR fields.Fq
  13. PF r1csqap.PolynomialField
  14. }
  15. type Program struct {
  16. functions map[string]*Circuit
  17. globalInputs []string
  18. arithmeticEnvironment utils //find a better name
  19. R1CS struct {
  20. A [][]*big.Int
  21. B [][]*big.Int
  22. C [][]*big.Int
  23. }
  24. }
  25. func (p *Program) PrintContraintTrees() {
  26. for k, v := range p.functions {
  27. fmt.Println(k)
  28. PrintTree(v.root)
  29. }
  30. }
  31. func (p *Program) BuildConstraintTrees() {
  32. mainRoot := p.getMainCircuit().root
  33. if mainRoot.value.Op&(MINUS|PLUS) != 0 {
  34. newOut := Constraint{Out: "out", V1: "1", V2: "out2", Op: MULTIPLY}
  35. p.getMainCircuit().addConstraint(&newOut)
  36. mainRoot.value.Out = "main@out2"
  37. p.getMainCircuit().gateMap[mainRoot.value.Out] = mainRoot
  38. }
  39. for _, in := range p.getMainCircuit().Inputs {
  40. p.globalInputs = append(p.globalInputs, composeNewFunction(in, p.getMainCircuit().Inputs))
  41. }
  42. var wg = sync.WaitGroup{}
  43. for _, circuit := range p.functions {
  44. wg.Add(1)
  45. func() {
  46. circuit.buildTree(circuit.root)
  47. wg.Done()
  48. }()
  49. }
  50. wg.Wait()
  51. return
  52. }
  53. func (c *Circuit) buildTree(g *gate) {
  54. if _, ex := c.gateMap[g.value.Out]; ex {
  55. if g.OperationType()&(IN|CONST) != 0 {
  56. return
  57. }
  58. } else {
  59. panic(fmt.Sprintf("undefined variable %s", g.value.Out))
  60. }
  61. if g.OperationType() == FUNC {
  62. //g.funcInputs = []*gate{}
  63. for _, in := range g.value.Inputs {
  64. if gate, ex := c.gateMap[in]; ex {
  65. g.funcInputs = append(g.funcInputs, gate)
  66. //note that we do repeated work here. the argument
  67. c.buildTree(gate)
  68. } else {
  69. panic(fmt.Sprintf("undefined argument %s", g.value.V1))
  70. }
  71. }
  72. return
  73. }
  74. if constr, ex := c.gateMap[g.value.V1]; ex {
  75. g.left = constr
  76. c.buildTree(g.left)
  77. } else {
  78. panic(fmt.Sprintf("undefined value %s", g.value.V1))
  79. }
  80. if constr, ex := c.gateMap[g.value.V2]; ex {
  81. g.right = constr
  82. c.buildTree(g.right)
  83. } else {
  84. panic(fmt.Sprintf("undefined value %s", g.value.V2))
  85. }
  86. }
  87. func (p *Program) ReduceCombinedTree() (orderedmGates []gate) {
  88. mGatesUsed := make(map[string]bool)
  89. orderedmGates = []gate{}
  90. p.r1CSRecursiveBuild(p.getMainCircuit(), p.getMainCircuit().root, mGatesUsed, &orderedmGates, false, false)
  91. return orderedmGates
  92. }
// r1CSRecursiveBuild walks the constraint tree rooted at root depth-first
// (left subtree, then right subtree). Every gate that needs its own R1CS row
// is appended to *orderedmGates, in the order the gates are finished.
//
// Parameters:
//   - currentCircuit: the circuit root belongs to. When root is a function
//     call, it is replaced by the called circuit.
//   - mGatesUsed: qualified names (built with composeNewFunction) of the
//     multiplication gates emitted so far. It prevents emitting a gate twice.
//   - orderedmGates: output list that emitted gates are appended to.
//   - negate, inverse: only forwarded to the recursive calls. On the right
//     branch they are XOR-ed with the gate's own negate/invert flags. This
//     function does not otherwise read them.
//
// It returns true if the subtree depends on a variable (it reaches an IN gate
// or an already emitted multiplication gate). It returns false if the subtree
// consists of constants only.
func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, root *gate, mGatesUsed map[string]bool, orderedmGates *[]gate, negate bool, inverse bool) (variableEnd bool) {
	if root.OperationType() == IN {
		return true
	}
	if root.OperationType() == CONST {
		return false
	}
	// A function call is replaced by the root of the called circuit.
	// extendedFunctionRenamer has already wired that circuit's input gates to
	// the call arguments.
	// NOTE(review): extendedFunctionRenamer returns nil for an unknown function,
	// which would make nextContext.root panic here.
	if root.OperationType() == FUNC {
		nextContext := p.extendedFunctionRenamer(currentCircuit, root.value)
		currentCircuit = nextContext
		root = nextContext.root
	}
	// Gate names are qualified by the circuit that defines them. If the gate's
	// qualified name was already emitted, it is treated as a variable instead
	// of being rebuilt.
	originOfVariable := p.functions[getContextFromVariable(root.value.Out)]
	if _, alreadyComputed := mGatesUsed[composeNewFunction(root.value.Out, originOfVariable.currentOutputs())]; alreadyComputed {
		return true
	}
	variableEnd = p.r1CSRecursiveBuild(currentCircuit, root.left, mGatesUsed, orderedmGates, negate, inverse)
	cons := p.r1CSRecursiveBuild(currentCircuit, root.right, mGatesUsed, orderedmGates, Xor(negate, root.value.negate), Xor(inverse, root.value.invert))
	if root.OperationType() == MULTIPLY {
		// A multiplication is only emitted in three cases: both operands depend
		// on variables, it is an inverted gate, or it is main's root. Otherwise
		// it is not emitted here. It is expanded by collectFactors when an
		// enclosing gate is emitted.
		if !(variableEnd && cons) && !root.value.invert && root != p.getMainCircuit().root {
			return variableEnd || cons
		}
		// Flatten both operands into linear combinations of constants, inputs
		// and previously emitted gates.
		// NOTE(review): this passes negate=false and invert=false, so the
		// negate/inverse arguments of this call do not influence the emitted row.
		root.leftIns = p.collectFactors(currentCircuit, root.left, mGatesUsed, false, false)
		//if root.left.value.Out== root.right.value.Out{
		//	//note this is not a full copy, but shouldnt be a problem
		//	root.rightIns= root.leftIns
		//}else{
		//	collectAtomsInSubtree(root.right, mGatesUsed, 1, root.rightIns, functionRootMap, Xor(negate, root.value.negate), Xor(inverse, root.value.invert))
		//}
		//root.rightIns = collectAtomsInSubtree3(root.right, mGatesUsed, Xor(negate, root.value.negate), Xor(inverse, root.value.invert))
		root.rightIns = p.collectFactors(currentCircuit, root.right, mGatesUsed, false, false)
		// In this function mGatesUsed and *orderedmGates grow together. This is
		// presumably the position the gate takes in orderedmGates; confirm that no
		// other code adds to mGatesUsed.
		root.index = len(mGatesUsed)
		var nn = composeNewFunction(root.value.Out, originOfVariable.currentOutputs())
		//var nn = root.value.Out
		//if _, ex := p.functions[root.value.Out]; ex {
		//	nn = currentCircuit.currentOutputName()
		//}
		// Sanity check: the same qualified gate must never be emitted twice.
		if _, ex := mGatesUsed[nn]; ex {
			panic(fmt.Sprintf("told ya so %v", nn))
		}
		mGatesUsed[nn] = true
		// Emit a detached copy under the qualified name. cloneGate copies the
		// Constraint, so the tree gate keeps its local name.
		rootGate := cloneGate(root)
		rootGate.value.Out = nn
		*orderedmGates = append(*orderedmGates, *rootGate)
	}
	return variableEnd || cons
	//TODO optimize if output is not a multiplication gate
}
  141. type factor struct {
  142. typ Token
  143. name string
  144. invert, negate bool
  145. multiplicative [2]int
  146. }
  147. func (f factor) String() string {
  148. if f.typ == CONST {
  149. return fmt.Sprintf("(const fac: %v)", f.multiplicative)
  150. }
  151. str := f.name
  152. if f.invert {
  153. str += "^-1"
  154. }
  155. if f.negate {
  156. str = "-" + str
  157. }
  158. return fmt.Sprintf("(\"%s\" fac: %v)", str, f.multiplicative)
  159. }
  160. func mul2DVector(a, b [2]int) [2]int {
  161. return [2]int{a[0] * b[0], a[1] * b[1]}
  162. }
  163. func mulFactors(leftFactors, rightFactors []factor) (result []factor) {
  164. for _, facLeft := range leftFactors {
  165. for i, facRight := range rightFactors {
  166. if facLeft.typ == CONST && facRight.typ == IN {
  167. rightFactors[i] = factor{typ: IN, name: facRight.name, negate: Xor(facLeft.negate, facRight.negate), invert: facRight.invert, multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  168. continue
  169. }
  170. if facRight.typ == CONST && facLeft.typ == IN {
  171. rightFactors[i] = factor{typ: IN, name: facLeft.name, negate: Xor(facLeft.negate, facRight.negate), invert: facLeft.invert, multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  172. continue
  173. }
  174. if facRight.typ&facLeft.typ == CONST {
  175. rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  176. continue
  177. }
  178. //tricky part here
  179. //this one should only be reached, after a true mgate had its left and right braches computed. here we
  180. //a factor can appear at most in quadratic form. we reduce terms a*a^-1 here.
  181. if facRight.typ&facLeft.typ == IN {
  182. if facLeft.name == facRight.name {
  183. if facRight.invert != facLeft.invert {
  184. rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  185. continue
  186. }
  187. }
  188. //rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  189. //continue
  190. }
  191. fmt.Println("dsf")
  192. panic("unexpected")
  193. }
  194. }
  195. return rightFactors
  196. }
  197. //returns the absolute value of a signed int and a flag telling if the input was positive or not
  198. //this implementation is awesome and fast (see Henry S Warren, Hackers's Delight)
  199. func abs(n int) (val int, positive bool) {
  200. y := n >> 63
  201. return (n ^ y) - y, y == 0
  202. }
  203. //returns the reduced sum of two input factor arrays
  204. //if no reduction was done (worst case), it returns the concatenation of the input arrays
  205. func addFactors(leftFactors, rightFactors []factor) []factor {
  206. var found bool
  207. res := make([]factor, 0, len(leftFactors)+len(rightFactors))
  208. for _, facLeft := range leftFactors {
  209. found = false
  210. for i, facRight := range rightFactors {
  211. if facLeft.typ&facRight.typ == CONST {
  212. var a0, b0 = facLeft.multiplicative[0], facRight.multiplicative[0]
  213. if facLeft.negate {
  214. a0 *= -1
  215. }
  216. if facRight.negate {
  217. b0 *= -1
  218. }
  219. absValue, positive := abs(a0*facRight.multiplicative[1] + facLeft.multiplicative[1]*b0)
  220. rightFactors[i] = factor{typ: CONST, negate: !positive, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}}
  221. found = true
  222. //res = append(res, factor{typ: CONST, negate: negate, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}})
  223. break
  224. }
  225. if facLeft.typ&facRight.typ == IN && facLeft.invert == facRight.invert && facLeft.name == facRight.name {
  226. var a0, b0 = facLeft.multiplicative[0], facRight.multiplicative[0]
  227. if facLeft.negate {
  228. a0 *= -1
  229. }
  230. if facRight.negate {
  231. b0 *= -1
  232. }
  233. absValue, positive := abs(a0*facRight.multiplicative[1] + facLeft.multiplicative[1]*b0)
  234. rightFactors[i] = factor{typ: IN, invert: facRight.invert, name: facRight.name, negate: !positive, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}}
  235. found = true
  236. //res = append(res, factor{typ: CONST, negate: negate, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}})
  237. break
  238. }
  239. }
  240. if !found {
  241. res = append(res, facLeft)
  242. }
  243. }
  244. for _, val := range rightFactors {
  245. if val.multiplicative[0] != 0 {
  246. res = append(res, val)
  247. }
  248. }
  249. return res
  250. }
  251. func (p *Program) collectFactors(contextCircut *Circuit, node *gate, mGatesUsed map[string]bool, negate bool, invert bool) []factor {
  252. if node.OperationType() == CONST {
  253. b1, v1 := isValue(node.value.Out)
  254. if !b1 {
  255. panic("not a constant")
  256. }
  257. if invert {
  258. return []factor{{typ: CONST, negate: negate, multiplicative: [2]int{1, v1}}}
  259. }
  260. return []factor{{typ: CONST, negate: negate, multiplicative: [2]int{v1, 1}}}
  261. }
  262. if node.OperationType() == FUNC {
  263. nextContext := p.extendedFunctionRenamer(contextCircut, node.value)
  264. //if _, ex := mGatesUsed[nextContext.currentOutputName()]; ex {
  265. // return []factor{{typ: IN, name: nextContext.currentOutputName(), invert: invert, negate: negate, multiplicative: [2]int{1, 1}}}
  266. //}
  267. contextCircut = nextContext
  268. node = nextContext.root
  269. }
  270. originOfVariable := p.functions[getContextFromVariable(node.value.Out)]
  271. if originOfVariable == nil {
  272. fmt.Println("asdf")
  273. }
  274. lookingFOr := composeNewFunction(node.value.Out, originOfVariable.currentOutputs())
  275. //if _, ex := mGatesUsed[node.value.Out]; ex {
  276. // return []factor{{typ: IN, name: node.value.Out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}}
  277. //}
  278. if node.OperationType() == IN {
  279. return []factor{{typ: IN, name: lookingFOr, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}}
  280. }
  281. if _, alreadyComputed := mGatesUsed[lookingFOr]; alreadyComputed {
  282. return []factor{{typ: IN, name: lookingFOr, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}}
  283. }
  284. leftFactors := p.collectFactors(contextCircut, node.left, mGatesUsed, negate, invert)
  285. rightFactors := p.collectFactors(contextCircut, node.right, mGatesUsed, Xor(negate, node.value.negate), Xor(invert, node.value.invert))
  286. switch node.OperationType() {
  287. case MULTIPLY:
  288. return mulFactors(leftFactors, rightFactors)
  289. case PLUS:
  290. return addFactors(leftFactors, rightFactors)
  291. default:
  292. panic("unexpected gate")
  293. }
  294. }
  295. //copies a gate neglecting its references to other gates
  296. func cloneGate(in *gate) (out *gate) {
  297. constr := &Constraint{Inputs: in.value.Inputs, Out: in.value.Out, Op: in.value.Op, invert: in.value.invert, negate: in.value.negate, V2: in.value.V2, V1: in.value.V1}
  298. nRightins := make([]factor, len(in.rightIns))
  299. nLeftInst := make([]factor, len(in.leftIns))
  300. for k, v := range in.rightIns {
  301. nRightins[k] = v
  302. }
  303. for k, v := range in.leftIns {
  304. nLeftInst[k] = v
  305. }
  306. return &gate{value: constr, leftIns: nLeftInst, rightIns: nRightins, index: in.index}
  307. }
  308. func (p *Program) getMainCircuit() *Circuit {
  309. return p.functions["main"]
  310. }
  311. //func (p *Program) addGlobalInput(c Constraint) {
  312. // c.Out = "main@" + c.Out
  313. // p.globalInputs = append(p.globalInputs, c)
  314. //}
  315. func prepareUtils() utils {
  316. bn, err := bn128.NewBn128()
  317. if err != nil {
  318. panic(err)
  319. }
  320. // new Finite Field
  321. fqR := fields.NewFq(bn.R)
  322. // new Polynomial Field
  323. pf := r1csqap.NewPolynomialField(fqR)
  324. return utils{
  325. Bn: bn,
  326. FqR: fqR,
  327. PF: pf,
  328. }
  329. }
// extendedFunctionRenamer resolves the function call described by constraint
// to the called circuit. It links that circuit's input gates to the gates
// producing the call's arguments, so the callee's tree can be walked from
// the call site.
//
// It returns the called circuit, or nil if no circuit with the called name is
// registered in p.functions. It panics in these cases:
//   - constraint.Op is not FUNC;
//   - isFunction does not recognise constraint.Out as a call;
//   - an argument is itself a function call;
//   - the gate producing an argument cannot be found.
//
// NOTE(review): the callee's input gates are overwritten in place (value,
// left, right). Every call of the same function therefore rewires the same
// gates, and the wiring of an earlier call is lost. Confirm that callers
// consume the result before the next call of that function.
// NOTE(review): contextCircuit is currently unused.
func (p *Program) extendedFunctionRenamer(contextCircuit *Circuit, constraint *Constraint) (nextContext *Circuit) {
	if constraint.Op != FUNC {
		panic("not a function")
	}
	//if _, ex := contextCircuit.gateMap[constraint.Out]; !ex {
	//	panic("constraint must be within the contextCircuit circuit")
	//}
	b, n, _ := isFunction(constraint.Out)
	if !b {
		panic("not expected")
	}
	if newContext, v := p.functions[n]; v {
		// Author's open question: is constraint.Inputs always aligned with the
		// callee's declared inputs?
		// NOTE(review): the i-th argument is bound to newContext.Inputs[i]. More
		// arguments than declared inputs panic with an index error. Fewer
		// arguments leave the trailing inputs unlinked.
		for i, argument := range constraint.Inputs {
			// Constant arguments are skipped, so the matching input gate is left
			// untouched. NOTE(review): verify that constants are bound elsewhere.
			isConst, _ := isValue(argument)
			if isConst {
				continue
			}
			isFunc, _, _ := isFunction(argument)
			if isFunc {
				panic("functions as arguments no supported yet")
				//p.extendedFunctionRenamer(contextCircuit,)
			}
			// From here on the argument is assumed to be a variable; this may need
			// revisiting. Find the circuit in which the argument was defined.
			// NOTE(review): if that circuit is unknown, inputOriginCircuit is nil and
			// the gateMap access below panics.
			inputOriginCircuit := p.functions[getContextFromVariable(argument)]
			// Pick the gate that has the argument as its output.
			if gate, ex := inputOriginCircuit.gateMap[argument]; ex {
				// The callee's input gate was only an input so far. It now takes over
				// the argument gate's constraint and children, which links it to its
				// constituents in the calling circuit.
				oldGate := newContext.gateMap[newContext.Inputs[i]]
				oldGate.value = gate.value
				oldGate.right = gate.right
				oldGate.left = gate.left
			} else {
				panic("not expected")
			}
		}
		//newContext.renameInputs(constraint.Inputs)
		return newContext
	}
	return nil
}
  375. func NewProgram() (p *Program) {
  376. p = &Program{functions: map[string]*Circuit{}, globalInputs: []string{"one"}, arithmeticEnvironment: prepareUtils()}
  377. return
  378. }
  379. // GenerateR1CS generates the R1CS polynomials from the Circuit
  380. func (p *Program) GenerateReducedR1CS(mGates []gate) (a, b, c [][]*big.Int) {
  381. // from flat code to R1CS
  382. offset := len(p.globalInputs)
  383. // one + in1 +in2+... + gate1 + gate2 .. + out
  384. size := offset + len(mGates)
  385. indexMap := make(map[string]int)
  386. for i, v := range p.globalInputs {
  387. indexMap[v] = i
  388. }
  389. for i, v := range mGates {
  390. indexMap[v.value.Out] = i + offset
  391. }
  392. for _, gate := range mGates {
  393. if gate.OperationType() == MULTIPLY {
  394. aConstraint := r1csqap.ArrayOfBigZeros(size)
  395. bConstraint := r1csqap.ArrayOfBigZeros(size)
  396. cConstraint := r1csqap.ArrayOfBigZeros(size)
  397. for _, val := range gate.leftIns {
  398. if val.typ != CONST {
  399. if _, ex := indexMap[val.name]; !ex {
  400. panic(fmt.Sprintf("%v index not found!!!", val.name))
  401. }
  402. }
  403. convertAndInsertFactorAt(aConstraint, val, indexMap[val.name])
  404. }
  405. for _, val := range gate.rightIns {
  406. if val.typ != CONST {
  407. if _, ex := indexMap[val.name]; !ex {
  408. panic(fmt.Sprintf("%v index not found!!!", val.name))
  409. }
  410. }
  411. convertAndInsertFactorAt(bConstraint, val, indexMap[val.name])
  412. }
  413. cConstraint[indexMap[gate.value.Out]] = big.NewInt(int64(1))
  414. if gate.value.invert {
  415. tmp := aConstraint
  416. aConstraint = cConstraint
  417. cConstraint = tmp
  418. }
  419. a = append(a, aConstraint)
  420. b = append(b, bConstraint)
  421. c = append(c, cConstraint)
  422. } else {
  423. panic("not a m gate")
  424. }
  425. }
  426. p.R1CS.A = a
  427. p.R1CS.B = b
  428. p.R1CS.C = c
  429. return a, b, c
  430. }
  431. var Utils = prepareUtils()
  432. func fractionToField(in [2]int) *big.Int {
  433. return Utils.FqR.Mul(big.NewInt(int64(in[0])), Utils.FqR.Inverse(big.NewInt(int64(in[1]))))
  434. }
  435. func convertAndInsertFactorAt(arr []*big.Int, val factor, index int) {
  436. value := new(big.Int).Add(new(big.Int), fractionToField(val.multiplicative))
  437. if val.negate {
  438. value.Neg(value)
  439. }
  440. //not that index is 0 if its a constant, since 0 is the map default if no entry was found
  441. arr[index] = value
  442. }
  443. func (p *Program) CalculateWitness(input []*big.Int) (witness []*big.Int) {
  444. if len(p.globalInputs)-1 != len(input) {
  445. panic("input do not match the required inputs")
  446. }
  447. witness = r1csqap.ArrayOfBigZeros(len(p.R1CS.A[0]))
  448. set := make([]bool, len(witness))
  449. witness[0] = big.NewInt(int64(1))
  450. set[0] = true
  451. for i := range input {
  452. witness[i+1] = input[i]
  453. set[i+1] = true
  454. }
  455. zero := big.NewInt(int64(0))
  456. for i := 0; i < len(p.R1CS.A); i++ {
  457. gatesLeftInputs := p.R1CS.A[i]
  458. gatesRightInputs := p.R1CS.B[i]
  459. gatesOutputs := p.R1CS.C[i]
  460. sumLeft := big.NewInt(int64(0))
  461. sumRight := big.NewInt(int64(0))
  462. sumOut := big.NewInt(int64(0))
  463. index := -1
  464. division := false
  465. for j, val := range gatesLeftInputs {
  466. if val.Cmp(zero) != 0 {
  467. if !set[j] {
  468. index = j
  469. division = true
  470. break
  471. }
  472. sumLeft.Add(sumLeft, new(big.Int).Mul(val, witness[j]))
  473. }
  474. }
  475. for j, val := range gatesRightInputs {
  476. if val.Cmp(zero) != 0 {
  477. sumRight.Add(sumRight, new(big.Int).Mul(val, witness[j]))
  478. }
  479. }
  480. for j, val := range gatesOutputs {
  481. if val.Cmp(zero) != 0 {
  482. if !set[j] {
  483. if index != -1 {
  484. panic("invalid R1CS form")
  485. }
  486. index = j
  487. break
  488. }
  489. sumOut.Add(sumOut, new(big.Int).Mul(val, witness[j]))
  490. }
  491. }
  492. if !division {
  493. set[index] = true
  494. witness[index] = new(big.Int).Mul(sumLeft, sumRight)
  495. } else {
  496. b := sumRight.Int64()
  497. c := sumOut.Int64()
  498. set[index] = true
  499. witness[index] = big.NewInt(c / b)
  500. }
  501. }
  502. return
  503. }