You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

582 lines
16 KiB

  1. package circuitcompiler
  2. import (
  3. "fmt"
  4. "github.com/mottla/go-snark/bn128"
  5. "github.com/mottla/go-snark/fields"
  6. "github.com/mottla/go-snark/r1csqap"
  7. "math/big"
  8. "sync"
  9. )
  10. type utils struct {
  11. Bn bn128.Bn128
  12. FqR fields.Fq
  13. PF r1csqap.PolynomialField
  14. }
  15. type Program struct {
  16. functions map[string]*Circuit
  17. globalInputs []Constraint
  18. arithmeticEnvironment utils //find a better name
  19. R1CS struct {
  20. A [][]*big.Int
  21. B [][]*big.Int
  22. C [][]*big.Int
  23. }
  24. }
  25. func (p *Program) PrintContraintTrees() {
  26. for k, v := range p.functions {
  27. fmt.Println(k)
  28. PrintTree(v.root)
  29. }
  30. }
  31. func (p *Program) BuildConstraintTrees() {
  32. mainRoot := p.getMainCircuit().root
  33. if mainRoot.value.Op&(MINUS|PLUS) != 0 {
  34. newOut := Constraint{Out: "out", V1: "1", V2: "out2", Op: MULTIPLY}
  35. p.getMainCircuit().addConstraint(&newOut)
  36. mainRoot.value.Out = "main@out2"
  37. p.getMainCircuit().gateMap[mainRoot.value.Out] = mainRoot
  38. }
  39. var wg = sync.WaitGroup{}
  40. for _, circuit := range p.functions {
  41. wg.Add(1)
  42. func() {
  43. circuit.buildTree(circuit.root)
  44. wg.Done()
  45. }()
  46. }
  47. wg.Wait()
  48. return
  49. }
  50. func (c *Circuit) buildTree(g *gate) {
  51. if _, ex := c.gateMap[g.value.Out]; ex {
  52. if g.OperationType()&(IN|CONST) != 0 {
  53. return
  54. }
  55. } else {
  56. panic(fmt.Sprintf("undefined variable %s", g.value.Out))
  57. }
  58. if g.OperationType() == FUNC {
  59. //g.funcInputs = []*gate{}
  60. for _, in := range g.value.Inputs {
  61. if gate, ex := c.gateMap[in]; ex {
  62. g.funcInputs = append(g.funcInputs, gate)
  63. //note that we do repeated work here. the argument
  64. c.buildTree(gate)
  65. } else {
  66. panic(fmt.Sprintf("undefined argument %s", g.value.V1))
  67. }
  68. }
  69. return
  70. }
  71. if constr, ex := c.gateMap[g.value.V1]; ex {
  72. g.left = constr
  73. c.buildTree(g.left)
  74. } else {
  75. panic(fmt.Sprintf("undefined value %s", g.value.V1))
  76. }
  77. if constr, ex := c.gateMap[g.value.V2]; ex {
  78. g.right = constr
  79. c.buildTree(g.right)
  80. } else {
  81. panic(fmt.Sprintf("undefined value %s", g.value.V2))
  82. }
  83. }
  84. func (p *Program) ReduceCombinedTree() (orderedmGates []gate) {
  85. mGatesUsed := make(map[string]bool)
  86. orderedmGates = []gate{}
  87. p.r1CSRecursiveBuild(p.getMainCircuit(), p.getMainCircuit().root, mGatesUsed, &orderedmGates, false, false)
  88. return orderedmGates
  89. }
  90. func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, root *gate, mGatesUsed map[string]bool, orderedmGates *[]gate, negate bool, inverse bool) (isConstant bool) {
  91. if root.OperationType() == IN {
  92. return false
  93. }
  94. if root.OperationType() == CONST {
  95. return true
  96. }
  97. if _, alreadyComputed := mGatesUsed[root.value.Out]; alreadyComputed {
  98. return false
  99. }
  100. if root.OperationType() == FUNC {
  101. nextContext := p.extendedFunctionRenamer(currentCircuit, root.value)
  102. isConstant = p.r1CSRecursiveBuild(nextContext, nextContext.root, mGatesUsed, orderedmGates, negate, inverse)
  103. return isConstant
  104. }
  105. if _, alreadyComputed := mGatesUsed[root.value.V1]; !alreadyComputed {
  106. isConstant = p.r1CSRecursiveBuild(currentCircuit, root.left, mGatesUsed, orderedmGates, negate, inverse)
  107. }
  108. if _, alreadyComputed := mGatesUsed[root.value.V2]; !alreadyComputed {
  109. cons := p.r1CSRecursiveBuild(currentCircuit, root.right, mGatesUsed, orderedmGates, Xor(negate, root.value.negate), Xor(inverse, root.value.invert))
  110. isConstant = isConstant || cons
  111. }
  112. if root.OperationType() == MULTIPLY {
  113. if isConstant && !root.value.invert && root != p.getMainCircuit().root {
  114. return false
  115. }
  116. root.leftIns = p.collectFactors(currentCircuit, root.left, mGatesUsed, false, false)
  117. //if root.left.value.Out== root.right.value.Out{
  118. // //note this is not a full copy, but shouldnt be a problem
  119. // root.rightIns= root.leftIns
  120. //}else{
  121. // collectAtomsInSubtree(root.right, mGatesUsed, 1, root.rightIns, functionRootMap, Xor(negate, root.value.negate), Xor(inverse, root.value.invert))
  122. //}
  123. //root.rightIns = collectAtomsInSubtree3(root.right, mGatesUsed, Xor(negate, root.value.negate), Xor(inverse, root.value.invert))
  124. root.rightIns = p.collectFactors(currentCircuit, root.right, mGatesUsed, false, false)
  125. root.index = len(mGatesUsed)
  126. var nn = root.value.Out
  127. //if _, ex := p.functions[nn]; ex {
  128. // nn = composeNewFunction(root.value.Out, currentCircuit.Inputs)
  129. //}
  130. mGatesUsed[nn] = true
  131. rootGate := cloneGate(root)
  132. rootGate.value.Out = nn
  133. *orderedmGates = append(*orderedmGates, *rootGate)
  134. }
  135. return isConstant
  136. //TODO optimize if output is not a multipication gate
  137. }
  138. type factor struct {
  139. typ Token
  140. name string
  141. invert, negate bool
  142. multiplicative [2]int
  143. }
  144. func (f factor) String() string {
  145. if f.typ == CONST {
  146. return fmt.Sprintf("(const fac: %v)", f.multiplicative)
  147. }
  148. str := f.name
  149. if f.invert {
  150. str += "^-1"
  151. }
  152. if f.negate {
  153. str = "-" + str
  154. }
  155. return fmt.Sprintf("(\"%s\" fac: %v)", str, f.multiplicative)
  156. }
  // mul2DVector multiplies a and b component by component.
  func mul2DVector(a, b [2]int) [2]int {
  	var product [2]int
  	for i := range product {
  		product[i] = a[i] * b[i]
  	}
  	return product
  }
  160. func mulFactors(leftFactors, rightFactors []factor) (result []factor) {
  161. for _, facLeft := range leftFactors {
  162. for i, facRight := range rightFactors {
  163. if facLeft.typ == CONST && facRight.typ == IN {
  164. rightFactors[i] = factor{typ: IN, name: facRight.name, negate: Xor(facLeft.negate, facRight.negate), invert: facRight.invert, multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  165. continue
  166. }
  167. if facRight.typ == CONST && facLeft.typ == IN {
  168. rightFactors[i] = factor{typ: IN, name: facLeft.name, negate: Xor(facLeft.negate, facRight.negate), invert: facLeft.invert, multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  169. continue
  170. }
  171. if facRight.typ&facLeft.typ == CONST {
  172. rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  173. continue
  174. }
  175. //tricky part here
  176. //this one should only be reached, after a true mgate had its left and right braches computed. here we
  177. //a factor can appear at most in quadratic form. we reduce terms a*a^-1 here.
  178. if facRight.typ&facLeft.typ == IN {
  179. //if facRight.n
  180. //rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  181. //continue
  182. }
  183. panic("unexpected")
  184. }
  185. }
  186. return rightFactors
  187. }
  // abs returns the absolute value of n together with a flag that is true when
  // n is not negative (zero counts as positive). For the most negative int the
  // negation wraps around and that same value is returned.
  func abs(n int) (val int, positive bool) {
  	if n < 0 {
  		return -n, false
  	}
  	return n, true
  }
  194. //returns the reduced sum of two input factor arrays
  195. //if no reduction was done (worst case), it returns the concatenation of the input arrays
  196. func addFactors(leftFactors, rightFactors []factor) []factor {
  197. var found bool
  198. res := make([]factor, 0, len(leftFactors)+len(rightFactors))
  199. for _, facLeft := range leftFactors {
  200. found = false
  201. for i, facRight := range rightFactors {
  202. if facLeft.typ&facRight.typ == CONST {
  203. var a0, b0 = facLeft.multiplicative[0], facRight.multiplicative[0]
  204. if facLeft.negate {
  205. a0 *= -1
  206. }
  207. if facRight.negate {
  208. b0 *= -1
  209. }
  210. absValue, positive := abs(a0*facRight.multiplicative[1] + facLeft.multiplicative[1]*b0)
  211. rightFactors[i] = factor{typ: CONST, negate: !positive, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}}
  212. found = true
  213. //res = append(res, factor{typ: CONST, negate: negate, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}})
  214. break
  215. }
  216. if facLeft.typ&facRight.typ == IN && facLeft.invert == facRight.invert && facLeft.name == facRight.name {
  217. var a0, b0 = facLeft.multiplicative[0], facRight.multiplicative[0]
  218. if facLeft.negate {
  219. a0 *= -1
  220. }
  221. if facRight.negate {
  222. b0 *= -1
  223. }
  224. absValue, positive := abs(a0*facRight.multiplicative[1] + facLeft.multiplicative[1]*b0)
  225. rightFactors[i] = factor{typ: IN, invert: facRight.invert, name: facRight.name, negate: !positive, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}}
  226. found = true
  227. //res = append(res, factor{typ: CONST, negate: negate, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}})
  228. break
  229. }
  230. }
  231. if !found {
  232. res = append(res, facLeft)
  233. }
  234. }
  235. for _, val := range rightFactors {
  236. if val.multiplicative[0] != 0 {
  237. res = append(res, val)
  238. }
  239. }
  240. return res
  241. }
  242. func (p *Program) collectFactors(contextCircut *Circuit, g *gate, mGatesUsed map[string]bool, negate bool, invert bool) []factor {
  243. if _, ex := mGatesUsed[g.value.Out]; ex {
  244. return []factor{{typ: IN, name: g.value.Out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}}
  245. }
  246. if g.OperationType() == IN {
  247. return []factor{{typ: IN, name: g.value.Out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}}
  248. }
  249. if g.OperationType() == FUNC {
  250. nextContext := p.extendedFunctionRenamer(contextCircut, g.value)
  251. return p.collectFactors(nextContext, nextContext.root, mGatesUsed, negate, invert)
  252. }
  253. if g.OperationType() == CONST {
  254. b1, v1 := isValue(g.value.Out)
  255. if !b1 {
  256. panic("not a constant")
  257. }
  258. if invert {
  259. return []factor{{typ: CONST, negate: negate, multiplicative: [2]int{1, v1}}}
  260. }
  261. return []factor{{typ: CONST, negate: negate, multiplicative: [2]int{v1, 1}}}
  262. }
  263. var leftFactors, rightFactors []factor
  264. if g.left.OperationType() == FUNC {
  265. nextContext := p.extendedFunctionRenamer(contextCircut, g.left.value)
  266. leftFactors = p.collectFactors(nextContext, nextContext.root, mGatesUsed, negate, invert)
  267. } else {
  268. leftFactors = p.collectFactors(contextCircut, g.left, mGatesUsed, negate, invert)
  269. }
  270. if g.right.OperationType() == FUNC {
  271. nextContext := p.extendedFunctionRenamer(contextCircut, g.right.value)
  272. rightFactors = p.collectFactors(nextContext, nextContext.root, mGatesUsed, Xor(negate, g.value.negate), Xor(invert, g.value.invert))
  273. } else {
  274. rightFactors = p.collectFactors(contextCircut, g.right, mGatesUsed, Xor(negate, g.value.negate), Xor(invert, g.value.invert))
  275. }
  276. switch g.OperationType() {
  277. case MULTIPLY:
  278. return mulFactors(leftFactors, rightFactors)
  279. case PLUS:
  280. return addFactors(leftFactors, rightFactors)
  281. default:
  282. panic("unexpected gate")
  283. }
  284. }
  285. //copies a gate neglecting its references to other gates
  286. func cloneGate(in *gate) (out *gate) {
  287. constr := &Constraint{Inputs: in.value.Inputs, Out: in.value.Out, Op: in.value.Op, invert: in.value.invert, negate: in.value.negate, V2: in.value.V2, V1: in.value.V1}
  288. nRightins := make([]factor, len(in.rightIns))
  289. nLeftInst := make([]factor, len(in.leftIns))
  290. for k, v := range in.rightIns {
  291. nRightins[k] = v
  292. }
  293. for k, v := range in.leftIns {
  294. nLeftInst[k] = v
  295. }
  296. return &gate{value: constr, leftIns: nLeftInst, rightIns: nRightins, index: in.index}
  297. }
  298. func (p *Program) getMainCircuit() *Circuit {
  299. return p.functions["main"]
  300. }
  301. func (p *Program) addGlobalInput(c Constraint) {
  302. c.Out = "main@" + c.Out
  303. p.globalInputs = append(p.globalInputs, c)
  304. }
  305. func prepareUtils() utils {
  306. bn, err := bn128.NewBn128()
  307. if err != nil {
  308. panic(err)
  309. }
  310. // new Finite Field
  311. fqR := fields.NewFq(bn.R)
  312. // new Polynomial Field
  313. pf := r1csqap.NewPolynomialField(fqR)
  314. return utils{
  315. Bn: bn,
  316. FqR: fqR,
  317. PF: pf,
  318. }
  319. }
  320. func (p *Program) extendedFunctionRenamer(contextCircuit *Circuit, constraint *Constraint) (nextContext *Circuit) {
  321. if constraint.Op != FUNC {
  322. panic("not a function")
  323. }
  324. //if _, ex := contextCircuit.gateMap[constraint.Out]; !ex {
  325. // panic("constraint must be within the contextCircuit circuit")
  326. //}
  327. if b, n, _ := isFunction(constraint.Out); b {
  328. if newContext, v := p.functions[n]; v {
  329. //am i certain that constraint.inputs is alwazs equal to n??? me dont like it
  330. for i, argument := range constraint.Inputs {
  331. isConst, _ := isValue(argument)
  332. if isConst {
  333. continue
  334. }
  335. isFunc, _, _ := isFunction(argument)
  336. if isFunc {
  337. panic("functions as arguments no supported yet")
  338. //p.extendedFunctionRenamer(contextCircuit,)
  339. }
  340. //at this point I assert that argument is a variable. This can become troublesome later
  341. inputOriginCircuit := p.functions[getContextFromVariable(argument)]
  342. if gate, ex := inputOriginCircuit.gateMap[argument]; ex {
  343. oldGate := newContext.gateMap[newContext.Inputs[i]]
  344. //we take the old gate which was nothing but a input
  345. //and link this input to its constituents comming from the calling contextCircuit.
  346. //i think this is pretty neat
  347. oldGate.value = gate.value
  348. oldGate.right = gate.right
  349. oldGate.left = gate.left
  350. } else {
  351. panic("not expected")
  352. }
  353. }
  354. newContext.renameInputs(constraint.Inputs)
  355. return newContext
  356. }
  357. } else {
  358. panic("not expected")
  359. }
  360. return nil
  361. }
  362. func NewProgram() (p *Program) {
  363. p = &Program{functions: map[string]*Circuit{}, globalInputs: []Constraint{{Op: IN, Out: "one"}}, arithmeticEnvironment: prepareUtils()}
  364. return
  365. }
  366. // GenerateR1CS generates the R1CS polynomials from the Circuit
  367. func (p *Program) GenerateReducedR1CS(mGates []gate) (a, b, c [][]*big.Int) {
  368. // from flat code to R1CS
  369. offset := len(p.globalInputs)
  370. // one + in1 +in2+... + gate1 + gate2 .. + out
  371. size := offset + len(mGates)
  372. indexMap := make(map[string]int)
  373. for i, v := range p.globalInputs {
  374. indexMap[v.Out] = i
  375. }
  376. for i, v := range mGates {
  377. indexMap[v.value.Out] = i + offset
  378. }
  379. for _, gate := range mGates {
  380. if gate.OperationType() == MULTIPLY {
  381. aConstraint := r1csqap.ArrayOfBigZeros(size)
  382. bConstraint := r1csqap.ArrayOfBigZeros(size)
  383. cConstraint := r1csqap.ArrayOfBigZeros(size)
  384. for _, val := range gate.leftIns {
  385. convertAndInsertFactorAt(aConstraint, val, indexMap[val.name])
  386. }
  387. for _, val := range gate.rightIns {
  388. convertAndInsertFactorAt(bConstraint, val, indexMap[val.name])
  389. }
  390. cConstraint[indexMap[gate.value.Out]] = big.NewInt(int64(1))
  391. if gate.value.invert {
  392. tmp := aConstraint
  393. aConstraint = cConstraint
  394. cConstraint = tmp
  395. }
  396. a = append(a, aConstraint)
  397. b = append(b, bConstraint)
  398. c = append(c, cConstraint)
  399. } else {
  400. panic("not a m gate")
  401. }
  402. }
  403. p.R1CS.A = a
  404. p.R1CS.B = b
  405. p.R1CS.C = c
  406. return a, b, c
  407. }
  408. var Utils = prepareUtils()
  409. func fractionToField(in [2]int) *big.Int {
  410. return Utils.FqR.Mul(big.NewInt(int64(in[0])), Utils.FqR.Inverse(big.NewInt(int64(in[1]))))
  411. }
  412. func convertAndInsertFactorAt(arr []*big.Int, val factor, index int) {
  413. value := new(big.Int).Add(new(big.Int), fractionToField(val.multiplicative))
  414. if val.negate {
  415. value.Neg(value)
  416. }
  417. if val.typ == CONST {
  418. arr[0] = value
  419. } else {
  420. arr[index] = value
  421. }
  422. }
  423. func (p *Program) CalculateWitness(input []*big.Int) (witness []*big.Int) {
  424. if len(p.globalInputs)-1 != len(input) {
  425. panic("input do not match the required inputs")
  426. }
  427. witness = r1csqap.ArrayOfBigZeros(len(p.R1CS.A[0]))
  428. set := make([]bool, len(witness))
  429. witness[0] = big.NewInt(int64(1))
  430. set[0] = true
  431. for i := range input {
  432. witness[i+1] = input[i]
  433. set[i+1] = true
  434. }
  435. zero := big.NewInt(int64(0))
  436. for i := 0; i < len(p.R1CS.A); i++ {
  437. gatesLeftInputs := p.R1CS.A[i]
  438. gatesRightInputs := p.R1CS.B[i]
  439. gatesOutputs := p.R1CS.C[i]
  440. sumLeft := big.NewInt(int64(0))
  441. sumRight := big.NewInt(int64(0))
  442. sumOut := big.NewInt(int64(0))
  443. index := -1
  444. division := false
  445. for j, val := range gatesLeftInputs {
  446. if val.Cmp(zero) != 0 {
  447. if !set[j] {
  448. index = j
  449. division = true
  450. break
  451. }
  452. sumLeft.Add(sumLeft, new(big.Int).Mul(val, witness[j]))
  453. }
  454. }
  455. for j, val := range gatesRightInputs {
  456. if val.Cmp(zero) != 0 {
  457. sumRight.Add(sumRight, new(big.Int).Mul(val, witness[j]))
  458. }
  459. }
  460. for j, val := range gatesOutputs {
  461. if val.Cmp(zero) != 0 {
  462. if !set[j] {
  463. if index != -1 {
  464. panic("invalid R1CS form")
  465. }
  466. index = j
  467. break
  468. }
  469. sumOut.Add(sumOut, new(big.Int).Mul(val, witness[j]))
  470. }
  471. }
  472. if !division {
  473. set[index] = true
  474. witness[index] = new(big.Int).Mul(sumLeft, sumRight)
  475. } else {
  476. b := sumRight.Int64()
  477. c := sumOut.Int64()
  478. set[index] = true
  479. witness[index] = big.NewInt(c / b)
  480. }
  481. }
  482. return
  483. }