You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

580 lines
16 KiB

  1. package circuitcompiler
  2. import (
  3. "fmt"
  4. "github.com/mottla/go-snark/bn128"
  5. "github.com/mottla/go-snark/fields"
  6. "github.com/mottla/go-snark/r1csqap"
  7. "math/big"
  8. )
  9. type utils struct {
  10. Bn bn128.Bn128
  11. FqR fields.Fq
  12. PF r1csqap.PolynomialField
  13. }
  14. type Program struct {
  15. functions map[string]*Circuit
  16. globalInputs []Constraint
  17. arithmeticEnvironment utils //find a better name
  18. extendedFunctionRenamer func(context *Circuit, c Constraint) (newContext *Circuit)
  19. R1CS struct {
  20. A [][]*big.Int
  21. B [][]*big.Int
  22. C [][]*big.Int
  23. }
  24. }
  25. func (p *Program) PrintContraintTrees() {
  26. for k, v := range p.functions {
  27. fmt.Println(k)
  28. PrintTree(v.root)
  29. }
  30. }
  31. func (p *Program) BuildConstraintTrees() {
  32. functionRootMap := make(map[string]*gate)
  33. for _, circuit := range p.functions {
  34. //circuit.addConstraint(p.oneConstraint())
  35. fName := composeNewFunction(circuit.Name, circuit.Inputs)
  36. root := circuit.gateMap[fName]
  37. functionRootMap[fName] = root
  38. circuit.root = root
  39. circuit.buildTree(root)
  40. }
  41. return
  42. }
  43. func (c *Circuit) buildTree(g *gate) {
  44. if _, ex := c.gateMap[g.value.Out]; ex {
  45. if g.OperationType()&(IN|CONST) != 0 {
  46. return
  47. }
  48. } else {
  49. panic(fmt.Sprintf("undefined variable %s", g.value.Out))
  50. }
  51. if g.OperationType() == FUNC {
  52. //g.funcInputs = []*gate{}
  53. for _, in := range g.value.Inputs {
  54. if gate, ex := c.gateMap[in]; ex {
  55. //sadf
  56. g.funcInputs = append(g.funcInputs, gate)
  57. //note that we do repeated work here. the argument
  58. c.buildTree(gate)
  59. } else {
  60. panic(fmt.Sprintf("undefined argument %s", g.value.V1))
  61. }
  62. }
  63. return
  64. }
  65. if constr, ex := c.gateMap[g.value.V1]; ex {
  66. g.left = constr
  67. c.buildTree(g.left)
  68. } else {
  69. panic(fmt.Sprintf("undefined value %s", g.value.V1))
  70. }
  71. if constr, ex := c.gateMap[g.value.V2]; ex {
  72. g.right = constr
  73. c.buildTree(g.right)
  74. } else {
  75. panic(fmt.Sprintf("undefined value %s", g.value.V2))
  76. }
  77. }
  78. func (p *Program) ReduceCombinedTree() (orderedmGates []gate) {
  79. mGatesUsed := make(map[string]bool)
  80. orderedmGates = []gate{}
  81. functionRootMap := make(map[string]*gate)
  82. for k, v := range p.functions {
  83. functionRootMap[k] = v.root
  84. }
  85. p.extendedFunctionRenamer = func(context *Circuit, c Constraint) (nextContext *Circuit) {
  86. if c.Op != FUNC {
  87. panic("not a function")
  88. }
  89. if _, ex := context.gateMap[c.Out]; !ex {
  90. panic("constraint mus be within the context circuit")
  91. }
  92. if b, name, in := isFunction(c.Out); b {
  93. if newContext, v := p.functions[name]; v {
  94. //fmt.Println("unrenamed thing")
  95. //PrintTree(k.root)
  96. for i, argument := range in {
  97. if gate, ex := context.gateMap[argument]; ex {
  98. oldGate := newContext.gateMap[newContext.Inputs[i]]
  99. //we take the old gate which was nothing but a input
  100. //and link this input to its constituents comming from the calling context.
  101. //i think this is pretty neat
  102. oldGate.value = gate.value
  103. oldGate.right = gate.right
  104. oldGate.left = gate.left
  105. } else {
  106. panic("not expected")
  107. }
  108. }
  109. newContext.renameInputs(in)
  110. //fmt.Println("renamed thing")
  111. //PrintTree(k.root)
  112. return newContext
  113. }
  114. }
  115. panic("not a function dude")
  116. return nil
  117. }
  118. //traverseCombinedMultiplicationGates(p.getMainCircut().root, mGatesUsed, &orderedmGates, functionRootMap, functionRenamer, false, false)
  119. //markMgates(p.getMainCircut().root, mGatesUsed, &orderedmGates, functionRenamer, false, false)
  120. p.markMgates2(p.getMainCircut(), p.getMainCircut().root, mGatesUsed, &orderedmGates, false, false)
  121. return orderedmGates
  122. }
  // markMgates2 walks the gate tree below root (inside contextCircut) in
  // post-order. Each multiplication gate it keeps is appended, as a clone, to
  // orderedmGates, and its output name is recorded in mGatesUsed.
  //
  // For every kept MULTIPLY gate it:
  //   - stores the linear combinations of its operands in root.leftIns and
  //     root.rightIns (built by collectAtomsInSubtree2),
  //   - sets root.index to the number of gates recorded before it.
  // FUNC gates are followed into the callee circuit via
  // p.extendedFunctionRenamer. negate and inverse are XOR-ed with the
  // right-hand constraint's flags on the way down, but they are only passed
  // on: collectAtomsInSubtree2 is always called with false, false.
  //
  // The result is true for CONST leaves and is OR-ed up from the operands that
  // were visited.
  // NOTE(review): a MULTIPLY gate whose flag is true, which is not an
  // inversion, and whose output does not parse (via isFunction) as function
  // "main" is NOT recorded and reports false. Confirm this is the intended
  // "multiplication by a constant stays linear" shortcut.
  123. func (p *Program) markMgates2(contextCircut *Circuit, root *gate, mGatesUsed map[string]bool, orderedmGates *[]gate, negate bool, inverse bool) (isConstant bool) {
  124. if root.OperationType() == IN {
  125. return false
  126. }
  127. if root.OperationType() == CONST {
  128. return true
  129. }
  130. if root.OperationType() == FUNC {
  131. nextContext := p.extendedFunctionRenamer(contextCircut, root.value)
  132. isConstant = p.markMgates2(nextContext, nextContext.root, mGatesUsed, orderedmGates, negate, inverse)
  133. } else {
  // operands whose multiplication gate is already recorded are not revisited
  134. if _, alreadyComputed := mGatesUsed[root.value.V1]; !alreadyComputed {
  135. isConstant = p.markMgates2(contextCircut, root.left, mGatesUsed, orderedmGates, negate, inverse)
  136. }
  137. if _, alreadyComputed := mGatesUsed[root.value.V2]; !alreadyComputed {
  138. cons := p.markMgates2(contextCircut, root.right, mGatesUsed, orderedmGates, Xor(negate, root.value.negate), Xor(inverse, root.value.invert))
  139. isConstant = isConstant || cons
  140. }
  141. }
  // record this multiplication gate together with the linear combinations of its operands
  142. if root.OperationType() == MULTIPLY {
  143. _, n, _ := isFunction(root.value.Out)
  144. if isConstant && !root.value.invert && n != "main" {
  145. return false
  146. }
  147. root.leftIns = p.collectAtomsInSubtree2(contextCircut, root.left, mGatesUsed, false, false)
  148. //if root.left.value.Out== root.right.value.Out{
  149. // //note this is not a full copy, but shouldnt be a problem
  150. // root.rightIns= root.leftIns
  151. //}else{
  152. // collectAtomsInSubtree(root.right, mGatesUsed, 1, root.rightIns, functionRootMap, Xor(negate, root.value.negate), Xor(inverse, root.value.invert))
  153. //}
  154. //root.rightIns = collectAtomsInSubtree3(root.right, mGatesUsed, Xor(negate, root.value.negate), Xor(inverse, root.value.invert))
  155. root.rightIns = p.collectAtomsInSubtree2(contextCircut, root.right, mGatesUsed, false, false)
  // the gate's index is the number of gates recorded before it
  156. root.index = len(mGatesUsed)
  157. mGatesUsed[root.value.Out] = true
  // a clone is stored so later changes to root's links are not shared
  158. rootGate := cloneGate(root)
  159. *orderedmGates = append(*orderedmGates, *rootGate)
  160. }
  161. return isConstant
  162. //TODO optimize if output is not a multipication gate
  163. }
  164. type factor struct {
  165. typ Token
  166. name string
  167. invert, negate bool
  168. multiplicative [2]int
  169. }
  170. func (f factor) String() string {
  171. if f.typ == CONST {
  172. return fmt.Sprintf("(const fac: %v)", f.multiplicative)
  173. }
  174. str := f.name
  175. if f.invert {
  176. str += "^-1"
  177. }
  178. if f.negate {
  179. str = "-" + str
  180. }
  181. return fmt.Sprintf("(\"%s\" fac: %v)", str, f.multiplicative)
  182. }
  // mul2DVector multiplies two pairs component-wise. For fractions stored as
// {numerator, denominator} this yields the (unreduced) product fraction.
func mul2DVector(a, b [2]int) [2]int {
	var product [2]int
	for i := range product {
		product[i] = a[i] * b[i]
	}
	return product
}
  186. func mulFactors(leftFactors, rightFactors []factor) (result []factor) {
  187. for _, facLeft := range leftFactors {
  188. for i, facRight := range rightFactors {
  189. if facLeft.typ == CONST && facRight.typ == IN {
  190. rightFactors[i] = factor{typ: IN, name: facRight.name, negate: Xor(facLeft.negate, facRight.negate), invert: facRight.invert, multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  191. continue
  192. }
  193. if facRight.typ == CONST && facLeft.typ == IN {
  194. rightFactors[i] = factor{typ: IN, name: facLeft.name, negate: Xor(facLeft.negate, facRight.negate), invert: facLeft.invert, multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  195. continue
  196. }
  197. if facRight.typ&facLeft.typ == CONST {
  198. rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  199. continue
  200. }
  201. //tricky part here
  202. //this one should only be reached, after a true mgate had its left and right braches computed. here we
  203. //a factor can appear at most in quadratic form. we reduce terms a*a^-1 here.
  204. if facRight.typ&facLeft.typ == IN {
  205. //if facRight.n
  206. //rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  207. //continue
  208. }
  209. panic("unexpected")
  210. }
  211. }
  212. return rightFactors
  213. }
  // abs returns the magnitude of n and a flag that is true when n is
// non-negative (zero counts as positive). As with unary minus, the most
// negative int wraps around to itself.
func abs(n int) (val int, positive bool) {
	if n >= 0 {
		return n, true
	}
	return -n, false
}
  220. //returns the reduced sum of two input factor arrays
  221. //if no reduction was done (worst case), it returns the concatenation of the input arrays
  222. func addFactors(leftFactors, rightFactors []factor) (res []factor) {
  223. var found bool
  224. for _, facLeft := range leftFactors {
  225. for i, facRight := range rightFactors {
  226. if facLeft.typ&facRight.typ == CONST {
  227. var a0, b0 = facLeft.multiplicative[0], facRight.multiplicative[0]
  228. if facLeft.negate {
  229. a0 *= -1
  230. }
  231. if facRight.negate {
  232. b0 *= -1
  233. }
  234. absValue, negate := abs(a0*facRight.multiplicative[1] + facLeft.multiplicative[1]*b0)
  235. rightFactors[i] = factor{typ: CONST, negate: negate, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}}
  236. found = true
  237. //res = append(res, factor{typ: CONST, negate: negate, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}})
  238. break
  239. }
  240. if facLeft.typ&facRight.typ == IN && facLeft.invert == facRight.invert && facLeft.name == facRight.name {
  241. var a0, b0 = facLeft.multiplicative[0], facRight.multiplicative[0]
  242. if facLeft.negate {
  243. a0 *= -1
  244. }
  245. if facRight.negate {
  246. b0 *= -1
  247. }
  248. absValue, negate := abs(a0*facRight.multiplicative[1] + facLeft.multiplicative[1]*b0)
  249. rightFactors[i] = factor{typ: IN, invert: facRight.invert, name: facRight.name, negate: negate, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}}
  250. found = true
  251. //res = append(res, factor{typ: CONST, negate: negate, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}})
  252. break
  253. }
  254. }
  255. if !found {
  256. res = append(res, facLeft)
  257. found = false
  258. }
  259. }
  260. return append(res, rightFactors...)
  261. }
  // collectAtomsInSubtree2 expresses the value of gate g (inside contextCircut)
  // as a list of factors, i.e. a linear combination of "atoms".
  //
  // Recursion stops at:
  //   - gates whose output is already recorded in mGatesUsed, and IN gates:
  //     both become one variable factor named after the gate output;
  //   - CONST gates: a constant factor whose value comes from isValue
  //     (presumably the parsed literal); with invert set, the value goes into
  //     the denominator.
  // FUNC gates are followed into the callee via p.extendedFunctionRenamer.
  // For other gates, the operand factors are combined with mulFactors
  // (MULTIPLY) or addFactors (PLUS); any other operation panics.
  //
  // negate/invert are applied to the produced leaves. The right operand
  // additionally XORs in g.value.negate / g.value.invert.
  // NOTE(review): the explicit FUNC handling for g.left / g.right skips the
  // mGatesUsed check on the FUNC gate's own output, which a plain recursive
  // call would perform first. Confirm this is intended.
  // NOTE(review): an inverted constant 0 yields the denominator 0.
  262. func (p *Program) collectAtomsInSubtree2(contextCircut *Circuit, g *gate, mGatesUsed map[string]bool, negate bool, invert bool) []factor {
  263. if _, ex := mGatesUsed[g.value.Out]; ex {
  264. return []factor{{typ: IN, name: g.value.Out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}}
  265. }
  266. if g.OperationType() == IN {
  267. return []factor{{typ: IN, name: g.value.Out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}}
  268. }
  269. if g.OperationType() == FUNC {
  270. nextContext := p.extendedFunctionRenamer(contextCircut, g.value)
  271. return p.collectAtomsInSubtree2(nextContext, nextContext.root, mGatesUsed, negate, invert)
  272. }
  273. if g.OperationType() == CONST {
  274. b1, v1 := isValue(g.value.Out)
  275. if !b1 {
  276. panic("not a constant")
  277. }
  278. if invert {
  279. return []factor{{typ: CONST, negate: negate, multiplicative: [2]int{1, v1}}}
  280. }
  281. return []factor{{typ: CONST, negate: negate, multiplicative: [2]int{v1, 1}}}
  282. }
  // left operand keeps the incoming flags
  283. var leftFactors, rightFactors []factor
  284. if g.left.OperationType() == FUNC {
  285. nextContext := p.extendedFunctionRenamer(contextCircut, g.left.value)
  286. leftFactors = p.collectAtomsInSubtree2(nextContext, nextContext.root, mGatesUsed, negate, invert)
  287. } else {
  288. leftFactors = p.collectAtomsInSubtree2(contextCircut, g.left, mGatesUsed, negate, invert)
  289. }
  // right operand also picks up this gate's negate/invert flags
  290. if g.right.OperationType() == FUNC {
  291. nextContext := p.extendedFunctionRenamer(contextCircut, g.right.value)
  292. rightFactors = p.collectAtomsInSubtree2(nextContext, nextContext.root, mGatesUsed, Xor(negate, g.value.negate), Xor(invert, g.value.invert))
  293. } else {
  294. rightFactors = p.collectAtomsInSubtree2(contextCircut, g.right, mGatesUsed, Xor(negate, g.value.negate), Xor(invert, g.value.invert))
  295. }
  296. switch g.OperationType() {
  297. case MULTIPLY:
  298. return mulFactors(leftFactors, rightFactors)
  299. case PLUS:
  300. return addFactors(leftFactors, rightFactors)
  301. default:
  302. panic("unexpected gate")
  303. }
  304. }
  305. //copies a gate neglecting its references to other gates
  306. func cloneGate(in *gate) (out *gate) {
  307. constr := Constraint{Inputs: in.value.Inputs, Out: in.value.Out, Op: in.value.Op, invert: in.value.invert, negate: in.value.negate, V2: in.value.V2, V1: in.value.V1}
  308. nRightins := make([]factor, len(in.rightIns))
  309. nLeftInst := make([]factor, len(in.leftIns))
  310. for k, v := range in.rightIns {
  311. nRightins[k] = v
  312. }
  313. for k, v := range in.leftIns {
  314. nLeftInst[k] = v
  315. }
  316. return &gate{value: constr, leftIns: nLeftInst, rightIns: nRightins, index: in.index}
  317. }
  318. func (p *Program) getMainCircut() *Circuit {
  319. return p.functions["main"]
  320. }
  321. func (p *Program) addGlobalInput(c Constraint) {
  322. p.globalInputs = append(p.globalInputs, c)
  323. }
  324. func prepareUtils() utils {
  325. bn, err := bn128.NewBn128()
  326. if err != nil {
  327. panic(err)
  328. }
  329. // new Finite Field
  330. fqR := fields.NewFq(bn.R)
  331. // new Polynomial Field
  332. pf := r1csqap.NewPolynomialField(fqR)
  333. return utils{
  334. Bn: bn,
  335. FqR: fqR,
  336. PF: pf,
  337. }
  338. }
  339. func NewProgramm() *Program {
  340. //return &Program{functions: map[string]*Circuit{}, signals: []string{}, globalInputs: []*Constraint{{Op: PLUS, V1:"1",V2:"0", Out: "one"}}}
  341. return &Program{functions: map[string]*Circuit{}, globalInputs: []Constraint{{Op: IN, Out: "one"}}, arithmeticEnvironment: prepareUtils()}
  342. }
  343. func (p *Program) addFunction(constraint *Constraint) (c *Circuit) {
  344. name := constraint.Out
  345. fmt.Println("try to add function ", name)
  346. b, name2, _ := isFunction(name)
  347. if !b {
  348. panic(fmt.Sprintf("not a function: %v", constraint))
  349. }
  350. name = name2
  351. if _, ex := p.functions[name]; ex {
  352. panic("function already declared")
  353. }
  354. c = newCircuit(name)
  355. p.functions[name] = c
  356. //I need the inputs to be defined as input constraints for each function for later renaming conventions
  357. //if constraint.Literal == "main" {
  358. for _, in := range constraint.Inputs {
  359. newConstr := Constraint{
  360. Op: IN,
  361. Out: in,
  362. }
  363. if name == "main" {
  364. p.addGlobalInput(newConstr)
  365. }
  366. c.addConstraint(newConstr)
  367. }
  368. //}
  369. c.Inputs = constraint.Inputs
  370. return
  371. }
  372. // GenerateR1CS generates the R1CS polynomials from the Circuit
  373. func (p *Program) GenerateReducedR1CS(mGates []gate) (a, b, c [][]*big.Int) {
  374. // from flat code to R1CS
  375. offset := len(p.globalInputs)
  376. // one + in1 +in2+... + gate1 + gate2 .. + out
  377. size := offset + len(mGates)
  378. indexMap := make(map[string]int)
  379. for i, v := range p.globalInputs {
  380. indexMap[v.Out] = i
  381. }
  382. for i, v := range mGates {
  383. indexMap[v.value.Out] = i + offset
  384. }
  385. for _, gate := range mGates {
  386. if gate.OperationType() == MULTIPLY {
  387. aConstraint := r1csqap.ArrayOfBigZeros(size)
  388. bConstraint := r1csqap.ArrayOfBigZeros(size)
  389. cConstraint := r1csqap.ArrayOfBigZeros(size)
  390. for _, val := range gate.leftIns {
  391. convertAndInsertFactorAt(aConstraint, val, indexMap[val.name])
  392. }
  393. for _, val := range gate.rightIns {
  394. convertAndInsertFactorAt(bConstraint, val, indexMap[val.name])
  395. }
  396. cConstraint[indexMap[gate.value.Out]] = big.NewInt(int64(1))
  397. if gate.value.invert {
  398. tmp := aConstraint
  399. aConstraint = cConstraint
  400. cConstraint = tmp
  401. }
  402. a = append(a, aConstraint)
  403. b = append(b, bConstraint)
  404. c = append(c, cConstraint)
  405. } else {
  406. panic("not a m gate")
  407. }
  408. }
  409. p.R1CS.A = a
  410. p.R1CS.B = b
  411. p.R1CS.C = c
  412. return a, b, c
  413. }
  // Utils holds package-level curve/field helpers created by prepareUtils when
  // the package is initialised (prepareUtils panics on failure). fractionToField
  // reads Utils.FqR. NOTE(review): this duplicates the helpers that NewProgramm
  // stores in Program.arithmeticEnvironment.
  414. var Utils = prepareUtils()
  415. func fractionToField(in [2]int) *big.Int {
  416. return Utils.FqR.Mul(big.NewInt(int64(in[0])), Utils.FqR.Inverse(big.NewInt(int64(in[1]))))
  417. }
  418. func convertAndInsertFactorAt(arr []*big.Int, val factor, index int) {
  419. if val.typ == CONST {
  420. arr[0] = new(big.Int).Add(arr[0], fractionToField(val.multiplicative))
  421. return
  422. }
  423. arr[index] = new(big.Int).Add(arr[index], fractionToField(val.multiplicative))
  424. }
  // CalculateWitness solves the stored p.R1CS row by row for the witness vector,
  // given the values of the program inputs (global inputs after "one", in order).
  // It must run after GenerateReducedR1CS, since it reads p.R1CS.A[0].
  //
  // witness[0] is fixed to 1 and witness[1..len(input)] are the inputs. Each row
  // is then expected to have exactly one entry whose witness is still unknown,
  // either in A (a "division" row) or in C:
  //   - unknown in C: witness = (sum of known A terms) * (sum of B terms)
  //   - unknown in A: witness = (sum of C terms) / (sum of B terms)
  // It panics if the number of inputs does not match the global inputs minus
  // "one", or if a row has unknown entries in both A and C.
  //
  // NOTE(review): the solver has the following limitations; confirm them
  // before relying on the result:
  //   - it ignores the coefficient of the unknown entry and any terms after it;
  //   - it does not check that B's entries are already known;
  //   - it uses plain integer arithmetic rather than field arithmetic; the
  //     division truncates via Int64 and panics when the B sum is 0;
  //   - it panics on set[-1] when a row has no unknown entry.
  425. func (p *Program) CalculateWitness(input []*big.Int) (witness []*big.Int) {
  426. if len(p.globalInputs)-1 != len(input) {
  427. panic("input do not match the required inputs")
  428. }
  429. witness = r1csqap.ArrayOfBigZeros(len(p.R1CS.A[0]))
  430. set := make([]bool, len(witness))
  431. witness[0] = big.NewInt(int64(1))
  432. set[0] = true
  433. for i := range input {
  434. witness[i+1] = input[i]
  435. set[i+1] = true
  436. }
  437. zero := big.NewInt(int64(0))
  438. for i := 0; i < len(p.R1CS.A); i++ {
  439. gatesLeftInputs := p.R1CS.A[i]
  440. gatesRightInputs := p.R1CS.B[i]
  441. gatesOutputs := p.R1CS.C[i]
  442. sumLeft := big.NewInt(int64(0))
  443. sumRight := big.NewInt(int64(0))
  444. sumOut := big.NewInt(int64(0))
  445. index := -1
  446. division := false
  // scan A: an unknown entry here marks a division row
  447. for j, val := range gatesLeftInputs {
  448. if val.Cmp(zero) != 0 {
  449. if !set[j] {
  450. index = j
  451. division = true
  452. break
  453. }
  454. sumLeft.Add(sumLeft, new(big.Int).Mul(val, witness[j]))
  455. }
  456. }
  457. for j, val := range gatesRightInputs {
  458. if val.Cmp(zero) != 0 {
  459. sumRight.Add(sumRight, new(big.Int).Mul(val, witness[j]))
  460. }
  461. }
  // scan C: an unknown entry here is the value this row determines
  462. for j, val := range gatesOutputs {
  463. if val.Cmp(zero) != 0 {
  464. if !set[j] {
  465. if index != -1 {
  466. panic("invalid R1CS form")
  467. }
  468. index = j
  469. break
  470. }
  471. sumOut.Add(sumOut, new(big.Int).Mul(val, witness[j]))
  472. }
  473. }
  // NOTE(review): index is still -1 here if the row had no unknown entry
  474. if !division {
  475. set[index] = true
  476. witness[index] = new(big.Int).Mul(sumLeft, sumRight)
  477. } else {
  478. b := sumRight.Int64()
  479. c := sumOut.Int64()
  480. set[index] = true
  481. witness[index] = big.NewInt(c / b)
  482. }
  483. }
  484. return
  485. }