You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

609 lines
18 KiB

  1. package circuitcompiler
  2. import (
  3. "crypto/sha256"
  4. "fmt"
  5. "github.com/mottla/go-snark/bn128"
  6. "github.com/mottla/go-snark/fields"
  7. "github.com/mottla/go-snark/r1csqap"
  8. "hash"
  9. "math/big"
  10. "sync"
  11. )
  // utils bundles the arithmetic objects used by the compiler, as assembled by
  // prepareUtils: the BN128 curve, the finite field whose order is the curve's
  // R, and the polynomial field built over that finite field.
  type utils struct {
  	Bn  bn128.Bn128
  	FqR fields.Fq
  	PF  r1csqap.PolynomialField
  }
  // R1CS is a rank-1 constraint system. Row i of A, B and C holds the
  // coefficient vectors of constraint i, indexed by witness position, so that
  // (A[i]·w) * (B[i]·w) = C[i]·w is expected to hold for a valid witness w
  // (see GenerateReducedR1CS and CalculateWitness).
  type R1CS struct {
  	A, B, C [][]*big.Int
  }
  // Program holds every parsed function of a source program together with the
  // state used while compiling it into an R1CS.
  type Program struct {
  	// functions maps a function name to its circuit; "main" is the entry point
  	// (see getMainCircuit).
  	functions map[string]*Circuit
  	// globalInputs lists the witness variables that precede the gate outputs.
  	// NewProgram seeds it with "one" and BuildConstraintTrees appends the inputs
  	// of the main circuit.
  	globalInputs []string
  	// arithmeticEnvironment holds the curve and field objects from prepareUtils.
  	arithmeticEnvironment utils //find a better name
  	// sha256Hasher is initialised by NewProgram.
  	// NOTE(review): not used by the code visible here.
  	sha256Hasher hash.Hash
  	// computedInContext memoises results so a variable is not computed a second
  	// time within the same calling context.
  	// Outer key: the hash chain describing where the variable is computed from,
  	// e.g. H( H(main(a,b)) , doSomething(x,z) ), where H is a hash function.
  	// Inner map: variable name -> variable name extended with the hash chain.
  	// This data structure may later be replaced by something less confusing.
  	computedInContext map[string]map[string]string
  	// computedFactors reduces the number of multiplication gates. It maps the
  	// signature of a multiplication's factor lists (see factorsSignature) to the
  	// output variable of the gate that computed it. A later multiplication with
  	// the very same factors reuses that gate instead of emitting a new one.
  	computedFactors map[string]string
  }
  38. func (p *Program) PrintContraintTrees() {
  39. for k, v := range p.functions {
  40. fmt.Println(k)
  41. PrintTree(v.root)
  42. }
  43. }
  44. func (p *Program) BuildConstraintTrees() {
  45. mainRoot := p.getMainCircuit().root
  46. if mainRoot.value.Op&(MINUS|PLUS) != 0 {
  47. newOut := Constraint{Out: "out", V1: "1", V2: "out2", Op: MULTIPLY}
  48. p.getMainCircuit().addConstraint(&newOut)
  49. mainRoot.value.Out = "main@out2"
  50. p.getMainCircuit().gateMap[mainRoot.value.Out] = mainRoot
  51. }
  52. for _, in := range p.getMainCircuit().Inputs {
  53. p.globalInputs = append(p.globalInputs, in)
  54. }
  55. var wg = sync.WaitGroup{}
  56. for _, circuit := range p.functions {
  57. wg.Add(1)
  58. //interesting: if circuit is not passed as argument, the program fails. duno why..
  59. go func(c *Circuit) {
  60. c.buildTree(c.root)
  61. wg.Done()
  62. }(circuit)
  63. }
  64. wg.Wait()
  65. return
  66. }
  67. func (c *Circuit) buildTree(g *gate) {
  68. if _, ex := c.gateMap[g.value.Out]; ex {
  69. if g.OperationType()&(IN|CONST) != 0 {
  70. return
  71. }
  72. } else {
  73. panic(fmt.Sprintf("undefined variable %s", g.value.Out))
  74. }
  75. if g.OperationType() == FUNC {
  76. //g.funcInputs = []*gate{}
  77. for _, in := range g.value.Inputs {
  78. if gate, ex := c.gateMap[in]; ex {
  79. g.funcInputs = append(g.funcInputs, gate)
  80. //note that we do repeated work here. the argument
  81. c.buildTree(gate)
  82. } else {
  83. panic(fmt.Sprintf("undefined argument %s", g.value.V1))
  84. }
  85. }
  86. return
  87. }
  88. if constr, ex := c.gateMap[g.value.V1]; ex {
  89. g.left = constr
  90. c.buildTree(g.left)
  91. } else {
  92. panic(fmt.Sprintf("undefined value %s", g.value.V1))
  93. }
  94. if constr, ex := c.gateMap[g.value.V2]; ex {
  95. g.right = constr
  96. c.buildTree(g.right)
  97. } else {
  98. panic(fmt.Sprintf("undefined value %s", g.value.V2))
  99. }
  100. }
  101. func (p *Program) ReduceCombinedTree() (orderedmGates []gate) {
  102. //mGatesUsed := make(map[string]bool)
  103. orderedmGates = []gate{}
  104. p.computedInContext = make(map[string]map[string]string)
  105. p.computedFactors = make(map[string]string)
  106. rootHash := []byte{}
  107. p.computedInContext[string(rootHash)] = make(map[string]string)
  108. p.r1CSRecursiveBuild(p.getMainCircuit(), p.getMainCircuit().root, rootHash, &orderedmGates, false, false)
  109. return orderedmGates
  110. }
  // r1CSRecursiveBuild recursively walks the parse tree below node. It appends
  // to *orderedmGates the multiplication gates needed for the QAP construction.
  //
  // Additions (and subtractions, expressed through the negate flag) are never
  // turned into gates; they are folded into the returned factor list.
  // A multiplication is also folded, unless one of these holds:
  //   - both operands contain a variable;
  //   - the node is inverted (a division);
  //   - the node is the root of the main circuit.
  // In those cases it is emitted as a gate.
  //
  // Parameters:
  //   - currentCircuit: the circuit (function) node belongs to.
  //   - node: the gate to translate.
  //   - hashTraceBuildup: byte trace of the calling context. It keys
  //     p.computedInContext and is extended when a function call is entered.
  //   - orderedmGates: list to which newly created gates are appended.
  //   - negate, invert: sign and inversion inherited from the parent.
  //
  // Returns:
  //   - facs: factors describing node's value in terms of inputs, constants and
  //     outputs of already emitted gates.
  //   - hashTraceResult: byte trace of the returned value. Its first 10 bytes
  //     (raw bytes, not hex) are appended to gate variable names.
  //   - variableEnd: true if facs involves at least one variable, false if it
  //     is constant only.
  //
  // TODO: a variable computed in context A and needed again from a different
  // context B is recomputed instead of reused.
  func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTraceBuildup []byte, orderedmGates *[]gate, negate bool, invert bool) (facs []factor, hashTraceResult []byte, variableEnd bool) {
  	if node.OperationType() == CONST {
  		b1, v1 := isValue(node.value.Out)
  		if !b1 {
  			panic("not a constant")
  		}
  		// The constant's coefficient is the fraction v1/1, or 1/v1 when inverted.
  		mul := [2]int{v1, 1}
  		if invert {
  			mul = [2]int{1, v1}
  		}
  		// Ten zero bytes as trace, presumably so callers can always slice [:10].
  		return []factor{{typ: CONST, negate: negate, multiplicative: mul}}, make([]byte, 10), false
  	}
  	if node.OperationType() == FUNC {
  		// Step into the called function. extendedFunctionRenamer rewires the
  		// callee's inputs to the argument gates, and the walk continues at the
  		// callee's root in a context keyed by the extended trace.
  		// NOTE(review): extendedFunctionRenamer returns nil for an unknown
  		// function, which would make nextContext.root panic here.
  		nextContext := p.extendedFunctionRenamer(currentCircuit, node.value)
  		currentCircuit = nextContext
  		node = nextContext.root
  		hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(currentCircuit.currentOutputName()))
  		if _, ex := p.computedInContext[string(hashTraceBuildup)]; !ex {
  			p.computedInContext[string(hashTraceBuildup)] = make(map[string]string)
  		}
  	}
  	if node.OperationType() == IN {
  		// Leaf variable: it becomes a single factor with coefficient 1/1.
  		fac := factor{typ: IN, name: node.value.Out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}
  		hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(node.value.Out))
  		return []factor{fac}, hashTraceBuildup, true
  	}
  	if out, ex := p.computedInContext[string(hashTraceBuildup)][node.value.Out]; ex {
  		// Already computed in this context: reuse its renamed output as a variable.
  		fac := factor{typ: IN, name: out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}
  		hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(node.value.Out))
  		return []factor{fac}, hashTraceBuildup, true
  	}
  	// The left operand inherits the parent's sign and inversion. The right
  	// operand additionally applies this node's own negate (subtraction) and
  	// invert (division).
  	// NOTE(review): for a MULTIPLY node both operands inherit the parent's
  	// negate, and mulFactors XORs the signs, so a negated product may lose its
  	// sign; confirm this is intended.
  	leftFactors, leftHash, variableEnd := p.r1CSRecursiveBuild(currentCircuit, node.left, hashTraceBuildup, orderedmGates, negate, invert)
  	rightFactors, rightHash, cons := p.r1CSRecursiveBuild(currentCircuit, node.right, hashTraceBuildup, orderedmGates, Xor(negate, node.value.negate), Xor(invert, node.value.invert))
  	if node.OperationType() == MULTIPLY {
  		if !(variableEnd && cons) && !node.value.invert && node != p.getMainCircuit().root {
  			// At most one side contains a variable: the product is linear, so
  			// fold it into the factor list instead of emitting a gate.
  			return mulFactors(leftFactors, rightFactors), append(leftHash, rightHash...), variableEnd || cons
  		}
  		sig := factorsSignature(leftFactors, rightFactors)
  		if out, ex := p.computedFactors[sig]; ex {
  			// An earlier gate multiplied exactly these factor lists: reuse its output.
  			// NOTE(review): this returns the context trace unchanged. In the
  			// top-level context that trace is empty, and a parent gate slicing
  			// leftHash[:10] / rightHash[:10] below would panic; confirm.
  			return []factor{{typ: IN, name: out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}}, hashTraceBuildup, true
  		}
  		// Emit a new multiplication gate. It gets its position in the output list
  		// as index, and its operand/output names are made context specific by
  		// appending the first 10 bytes of the corresponding traces.
  		rootGate := cloneGate(node)
  		rootGate.index = len(*orderedmGates)
  		rootGate.leftIns = leftFactors
  		rootGate.rightIns = rightFactors
  		out := hashTogether(leftHash, rightHash)
  		rootGate.value.V1 = rootGate.value.V1 + string(leftHash[:10])
  		rootGate.value.V2 = rootGate.value.V2 + string(rightHash[:10])
  		rootGate.value.Out = rootGate.value.Out + string(out[:10])
  		p.computedInContext[string(hashTraceBuildup)][node.value.Out] = rootGate.value.Out
  		p.computedFactors[sig] = rootGate.value.Out
  		*orderedmGates = append(*orderedmGates, *rootGate)
  		hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(rootGate.value.Out))
  		// The parent sees the gate's output as a plain variable.
  		return []factor{{typ: IN, name: rootGate.value.Out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}}, hashTraceBuildup, true
  	}
  	switch node.OperationType() {
  	case PLUS:
  		return addFactors(leftFactors, rightFactors), hashTogether(leftHash, rightHash), variableEnd || cons
  	default:
  		panic("unexpected gate")
  	}
  }
  178. type factor struct {
  179. typ Token
  180. name string
  181. invert, negate bool
  182. multiplicative [2]int
  183. }
  184. func (f factor) String() string {
  185. if f.typ == CONST {
  186. return fmt.Sprintf("(const fac: %v)", f.multiplicative)
  187. }
  188. str := f.name
  189. if f.invert {
  190. str += "^-1"
  191. }
  192. if f.negate {
  193. str = "-" + str
  194. }
  195. return fmt.Sprintf("(\"%s\" fac: %v)", str, f.multiplicative)
  196. }
  // mul2DVector multiplies two 2-vectors component-wise. Applied to
  // (numerator, denominator) pairs, this is the product of two fractions.
  func mul2DVector(a, b [2]int) [2]int {
  	var product [2]int
  	for k := range product {
  		product[k] = a[k] * b[k]
  	}
  	return product
  }
  200. func factorsSignature(leftFactors, rightFactors []factor) string {
  201. hasher.Reset()
  202. //not using a kommutative operation here would be better. since a * b = b * a, but H(a,b) != H(b,a)
  203. //could use (g^a)^b == (g^b)^a where g is a generator of some prime field where the dicrete log is known to be hard
  204. for _, facLeft := range leftFactors {
  205. hasher.Write([]byte(facLeft.String()))
  206. }
  207. for _, Righ := range rightFactors {
  208. hasher.Write([]byte(Righ.String()))
  209. }
  210. return string(hasher.Sum(nil))[:16]
  211. }
  // mulFactors multiplies two factor lists without emitting a gate. The caller
  // (r1CSRecursiveBuild) uses it only when at most one side contains a variable.
  //
  // Each left factor is combined with every right factor, and the result
  // overwrites rightFactors[i] in place:
  //   - constant * variable: the variable, with the coefficients multiplied and
  //     the signs XORed;
  //   - constant * constant: a constant;
  //   - x * x^-1: a constant.
  // Any other combination panics ("unexpected").
  // The returned slice is rightFactors itself, so the caller's slice is
  // modified.
  // NOTE(review): when leftFactors has several entries, later entries meet the
  // already-overwritten rightFactors. For example (x+y)*c turns c into x*c
  // first, and then y*(x*c) reaches the panic. Confirm this path is unreachable.
  func mulFactors(leftFactors, rightFactors []factor) (result []factor) {
  	for _, facLeft := range leftFactors {
  		for i, facRight := range rightFactors {
  			// constant * variable
  			if facLeft.typ == CONST && facRight.typ == IN {
  				rightFactors[i] = factor{typ: IN, name: facRight.name, negate: Xor(facLeft.negate, facRight.negate), invert: facRight.invert, multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  				continue
  			}
  			// variable * constant
  			if facRight.typ == CONST && facLeft.typ == IN {
  				rightFactors[i] = factor{typ: IN, name: facLeft.name, negate: Xor(facLeft.negate, facRight.negate), invert: facLeft.invert, multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  				continue
  			}
  			// constant * constant
  			if facRight.typ&facLeft.typ == CONST {
  				rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  				continue
  			}
  			// Tricky part: variable * variable. This should only be reached after
  			// a real multiplication gate had its left and right branches computed,
  			// so a factor appears at most in quadratic form. Only the reduction
  			// x * x^-1 = constant is handled here.
  			if facRight.typ&facLeft.typ == IN {
  				if facLeft.name == facRight.name {
  					if facRight.invert != facLeft.invert {
  						rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  						continue
  					}
  				}
  				//rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  				//continue
  			}
  			panic("unexpected")
  		}
  	}
  	return rightFactors
  }
  // abs returns the magnitude of n together with a flag that is true when n is
  // non-negative (so abs(0) reports true).
  //
  // The magnitude of the most negative int is not representable: negation
  // wraps around, so that value is returned unchanged with the flag false.
  func abs(n int) (val int, positive bool) {
  	if n < 0 {
  		return -n, false
  	}
  	return n, true
  }
  251. //returns the reduced sum of two input factor arrays
  252. //if no reduction was done (worst case), it returns the concatenation of the input arrays
  253. func addFactors(leftFactors, rightFactors []factor) []factor {
  254. var found bool
  255. res := make([]factor, 0, len(leftFactors)+len(rightFactors))
  256. for _, facLeft := range leftFactors {
  257. found = false
  258. for i, facRight := range rightFactors {
  259. if facLeft.typ&facRight.typ == CONST {
  260. var a0, b0 = facLeft.multiplicative[0], facRight.multiplicative[0]
  261. if facLeft.negate {
  262. a0 *= -1
  263. }
  264. if facRight.negate {
  265. b0 *= -1
  266. }
  267. absValue, positive := abs(a0*facRight.multiplicative[1] + facLeft.multiplicative[1]*b0)
  268. rightFactors[i] = factor{typ: CONST, negate: !positive, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}}
  269. found = true
  270. //res = append(res, factor{typ: CONST, negate: negate, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}})
  271. break
  272. }
  273. if facLeft.typ&facRight.typ == IN && facLeft.invert == facRight.invert && facLeft.name == facRight.name {
  274. var a0, b0 = facLeft.multiplicative[0], facRight.multiplicative[0]
  275. if facLeft.negate {
  276. a0 *= -1
  277. }
  278. if facRight.negate {
  279. b0 *= -1
  280. }
  281. absValue, positive := abs(a0*facRight.multiplicative[1] + facLeft.multiplicative[1]*b0)
  282. rightFactors[i] = factor{typ: IN, invert: facRight.invert, name: facRight.name, negate: !positive, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}}
  283. found = true
  284. //res = append(res, factor{typ: CONST, negate: negate, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}})
  285. break
  286. }
  287. }
  288. if !found {
  289. res = append(res, facLeft)
  290. }
  291. }
  292. for _, val := range rightFactors {
  293. if val.multiplicative[0] != 0 {
  294. res = append(res, val)
  295. }
  296. }
  297. return res
  298. }
  299. //copies a gate neglecting its references to other gates
  300. func cloneGate(in *gate) (out *gate) {
  301. constr := &Constraint{Inputs: in.value.Inputs, Out: in.value.Out, Op: in.value.Op, invert: in.value.invert, negate: in.value.negate, V2: in.value.V2, V1: in.value.V1}
  302. nRightins := make([]factor, len(in.rightIns))
  303. nLeftInst := make([]factor, len(in.leftIns))
  304. for k, v := range in.rightIns {
  305. nRightins[k] = v
  306. }
  307. for k, v := range in.leftIns {
  308. nLeftInst[k] = v
  309. }
  310. return &gate{value: constr, leftIns: nLeftInst, rightIns: nRightins, index: in.index}
  311. }
  312. func (p *Program) getMainCircuit() *Circuit {
  313. return p.functions["main"]
  314. }
  315. func prepareUtils() utils {
  316. bn, err := bn128.NewBn128()
  317. if err != nil {
  318. panic(err)
  319. }
  320. // new Finite Field
  321. fqR := fields.NewFq(bn.R)
  322. // new Polynomial Field
  323. pf := r1csqap.NewPolynomialField(fqR)
  324. return utils{
  325. Bn: bn,
  326. FqR: fqR,
  327. PF: pf,
  328. }
  329. }
  330. func (p *Program) extendedFunctionRenamer(contextCircuit *Circuit, constraint *Constraint) (nextContext *Circuit) {
  331. if constraint.Op != FUNC {
  332. panic("not a function")
  333. }
  334. //if _, ex := contextCircuit.gateMap[constraint.Out]; !ex {
  335. // panic("constraint must be within the contextCircuit circuit")
  336. //}
  337. b, n, _ := isFunction(constraint.Out)
  338. if !b {
  339. panic("not expected")
  340. }
  341. if newContext, v := p.functions[n]; v {
  342. //am i certain that constraint.inputs is alwazs equal to n??? me dont like it
  343. for i, argument := range constraint.Inputs {
  344. isConst, _ := isValue(argument)
  345. if isConst {
  346. continue
  347. }
  348. isFunc, _, _ := isFunction(argument)
  349. if isFunc {
  350. panic("functions as arguments no supported yet")
  351. //p.extendedFunctionRenamer(contextCircuit,)
  352. }
  353. //at this point I assert that argument is a variable. This can become troublesome later
  354. //first we get the circuit in which the argument was created
  355. inputOriginCircuit := p.functions[getContextFromVariable(argument)]
  356. //we pick the gate that has the argument as output
  357. if gate, ex := inputOriginCircuit.gateMap[argument]; ex {
  358. //we pick the old circuit inputs and let them now reference the same as the argument gate did,
  359. oldGate := newContext.gateMap[newContext.Inputs[i]]
  360. //we take the old gate which was nothing but a input
  361. //and link this input to its constituents coming from the calling contextCircuit.
  362. //i think this is pretty neat
  363. oldGate.value = gate.value
  364. oldGate.right = gate.right
  365. oldGate.left = gate.left
  366. } else {
  367. panic("not expected")
  368. }
  369. }
  370. //newContext.renameInputs(constraint.Inputs)
  371. return newContext
  372. }
  373. return nil
  374. }
  375. func NewProgram() (p *Program) {
  376. p = &Program{
  377. functions: map[string]*Circuit{},
  378. globalInputs: []string{"one"},
  379. arithmeticEnvironment: prepareUtils(),
  380. sha256Hasher: sha256.New(),
  381. }
  382. return
  383. }
  384. // GenerateR1CS generates the R1CS polynomials from the Circuit
  385. func (p *Program) GenerateReducedR1CS(mGates []gate) (r1CS R1CS) {
  386. // from flat code to R1CS
  387. offset := len(p.globalInputs)
  388. // one + in1 +in2+... + gate1 + gate2 .. + out
  389. size := offset + len(mGates)
  390. indexMap := make(map[string]int)
  391. for i, v := range p.globalInputs {
  392. indexMap[v] = i
  393. }
  394. for i, v := range mGates {
  395. indexMap[v.value.Out] = i + offset
  396. }
  397. for _, g := range mGates {
  398. if g.OperationType() == MULTIPLY {
  399. aConstraint := r1csqap.ArrayOfBigZeros(size)
  400. bConstraint := r1csqap.ArrayOfBigZeros(size)
  401. cConstraint := r1csqap.ArrayOfBigZeros(size)
  402. insertValue := func(val factor, arr []*big.Int) {
  403. if val.typ != CONST {
  404. if _, ex := indexMap[val.name]; !ex {
  405. panic(fmt.Sprintf("%v index not found!!!", val.name))
  406. }
  407. }
  408. value := new(big.Int).Add(new(big.Int), fractionToField(val.multiplicative))
  409. if val.negate {
  410. value.Neg(value)
  411. }
  412. //not that index is 0 if its a constant, since 0 is the map default if no entry was found
  413. arr[indexMap[val.name]] = value
  414. }
  415. for _, val := range g.leftIns {
  416. insertValue(val, aConstraint)
  417. }
  418. for _, val := range g.rightIns {
  419. insertValue(val, bConstraint)
  420. }
  421. cConstraint[indexMap[g.value.Out]] = big.NewInt(int64(1))
  422. if g.value.invert {
  423. tmp := aConstraint
  424. aConstraint = cConstraint
  425. cConstraint = tmp
  426. }
  427. r1CS.A = append(r1CS.A, aConstraint)
  428. r1CS.B = append(r1CS.B, bConstraint)
  429. r1CS.C = append(r1CS.C, cConstraint)
  430. } else {
  431. panic("not a m gate")
  432. }
  433. }
  434. return
  435. }
  436. var Utils = prepareUtils()
  437. func fractionToField(in [2]int) *big.Int {
  438. return Utils.FqR.Mul(big.NewInt(int64(in[0])), Utils.FqR.Inverse(big.NewInt(int64(in[1]))))
  439. }
  440. //Calculates the witness (program trace) given some input
  441. //asserts that R1CS has been computed and is stored in the program p memory calling this function
  442. func CalculateWitness(input []*big.Int, r1cs R1CS) (witness []*big.Int) {
  443. //if len(p.globalInputs)-1 != len(input) {
  444. // panic("input do not match the required inputs")
  445. //}
  446. witness = r1csqap.ArrayOfBigZeros(len(r1cs.A[0]))
  447. set := make([]bool, len(witness))
  448. witness[0] = big.NewInt(int64(1))
  449. set[0] = true
  450. for i := range input {
  451. witness[i+1] = input[i]
  452. set[i+1] = true
  453. }
  454. zero := big.NewInt(int64(0))
  455. for i := 0; i < len(r1cs.A); i++ {
  456. gatesLeftInputs := r1cs.A[i]
  457. gatesRightInputs := r1cs.B[i]
  458. gatesOutputs := r1cs.C[i]
  459. sumLeft := big.NewInt(int64(0))
  460. sumRight := big.NewInt(int64(0))
  461. sumOut := big.NewInt(int64(0))
  462. index := -1
  463. division := false
  464. for j, val := range gatesLeftInputs {
  465. if val.Cmp(zero) != 0 {
  466. if !set[j] {
  467. index = j
  468. division = true
  469. break
  470. }
  471. sumLeft.Add(sumLeft, new(big.Int).Mul(val, witness[j]))
  472. }
  473. }
  474. for j, val := range gatesRightInputs {
  475. if val.Cmp(zero) != 0 {
  476. sumRight.Add(sumRight, new(big.Int).Mul(val, witness[j]))
  477. }
  478. }
  479. for j, val := range gatesOutputs {
  480. if val.Cmp(zero) != 0 {
  481. if !set[j] {
  482. if index != -1 {
  483. panic("invalid R1CS form")
  484. }
  485. index = j
  486. break
  487. }
  488. sumOut.Add(sumOut, new(big.Int).Mul(val, witness[j]))
  489. }
  490. }
  491. if !division {
  492. set[index] = true
  493. witness[index] = new(big.Int).Mul(sumLeft, sumRight)
  494. } else {
  495. b := sumRight.Int64()
  496. c := sumOut.Int64()
  497. set[index] = true
  498. witness[index] = big.NewInt(c / b)
  499. }
  500. }
  501. return
  502. }
  // hasher is a package-level SHA-256 instance. It is shared mutable state and
  // not safe for concurrent use, so hashTogether does not use it. It is kept
  // for any remaining users in the package.
  var hasher = sha256.New()

  // hashTogether returns SHA-256(a || b), the digest of a followed by b.
  // Each call uses its own digest, so it is safe to call concurrently.
  func hashTogether(a, b []byte) []byte {
  	digest := sha256.New()
  	digest.Write(a) // hash.Hash.Write never returns an error
  	digest.Write(b)
  	return digest.Sum(nil)
  }