You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

626 lines
19 KiB

  1. package circuitcompiler
  2. import (
  3. "crypto/sha256"
  4. "fmt"
  5. "github.com/arnaucube/go-snark/bn128"
  6. "github.com/arnaucube/go-snark/fields"
  7. "github.com/arnaucube/go-snark/r1csqap"
  8. "math/big"
  9. "sync"
  10. )
  11. type utils struct {
  12. Bn bn128.Bn128
  13. FqR fields.Fq
  14. PF r1csqap.PolynomialField
  15. }
  // R1CS holds a rank-1 constraint system as three coefficient matrices.
  // Each matrix stores one row per constraint; row i of A, B and C together
  // describe constraint i.
  type R1CS struct {
  	A, B, C [][]*big.Int
  }
  21. type Program struct {
  22. functions map[string]*Circuit
  23. globalInputs []string
  24. globalOutput map[string]bool
  25. arithmeticEnvironment utils //find a better name
  26. //key 1: the hash chain indicating from where the variable is called H( H(main(a,b)) , doSomething(x,z) ), where H is a hash function.
  27. //value 1 : map
  28. // with key variable name
  29. // with value variable name + hash Chain
  30. //this datastructure is nice but maybe ill replace it later with something less confusing
  31. //it serves the elementary purpose of not computing a variable a second time.
  32. //it boosts parse time
  33. computedInContext map[string]map[string]string
  34. //to reduce the number of multiplication gates, we store each factor signature, and the variable name,
  35. //so each time a variable is computed, that happens to have the very same factors, we reuse the former
  36. //it boost setup and proof time
  37. computedFactors map[string]string
  38. }
  39. func (p *Program) GlobalInputCount() int {
  40. return len(p.globalInputs)
  41. }
  42. func (p *Program) PrintContraintTrees() {
  43. for k, v := range p.functions {
  44. fmt.Println(k)
  45. PrintTree(v.root)
  46. }
  47. }
  48. func (p *Program) BuildConstraintTrees() {
  49. mainRoot := p.getMainCircuit().root
  50. //if our programs last operation is not a multiplication gate, we need to introduce on
  51. if mainRoot.value.Op&(MINUS|PLUS) != 0 {
  52. newOut := Constraint{Out: "out", V1: "1", V2: "out2", Op: MULTIPLY}
  53. p.getMainCircuit().addConstraint(&newOut)
  54. mainRoot.value.Out = "main@out2"
  55. p.getMainCircuit().gateMap[mainRoot.value.Out] = mainRoot
  56. }
  57. for _, in := range p.getMainCircuit().Inputs {
  58. p.globalInputs = append(p.globalInputs, in)
  59. }
  60. for key, _ := range p.globalOutput {
  61. p.globalInputs = append(p.globalInputs, key)
  62. }
  63. //TODO do the same with the outputs
  64. var wg = sync.WaitGroup{}
  65. //we build the parse trees concurrently! because we can! go rocks
  66. for _, circuit := range p.functions {
  67. wg.Add(1)
  68. //interesting: if circuit is not passed as argument, the program fails. duno why..
  69. go func(c *Circuit) {
  70. c.buildTree(c.root)
  71. wg.Done()
  72. }(circuit)
  73. }
  74. wg.Wait()
  75. return
  76. }
  77. func (c *Circuit) buildTree(g *gate) {
  78. if _, ex := c.gateMap[g.value.Out]; ex {
  79. if g.OperationType()&(IN|CONST) != 0 {
  80. return
  81. }
  82. } else {
  83. panic(fmt.Sprintf("undefined variable %s", g.value.Out))
  84. }
  85. if g.OperationType() == FUNC {
  86. for _, in := range g.value.Inputs {
  87. if gate, ex := c.gateMap[in]; ex {
  88. g.funcInputs = append(g.funcInputs, gate)
  89. c.buildTree(gate)
  90. } else {
  91. panic(fmt.Sprintf("undefined argument %s", g.value.V1))
  92. }
  93. }
  94. return
  95. }
  96. if constr, ex := c.gateMap[g.value.V1]; ex {
  97. g.left = constr
  98. c.buildTree(g.left)
  99. } else {
  100. panic(fmt.Sprintf("undefined value %s", g.value.V1))
  101. }
  102. if constr, ex := c.gateMap[g.value.V2]; ex {
  103. g.right = constr
  104. c.buildTree(g.right)
  105. } else {
  106. panic(fmt.Sprintf("undefined value %s", g.value.V2))
  107. }
  108. }
  109. func (p *Program) ReduceCombinedTree() (orderedmGates []gate) {
  110. orderedmGates = []gate{}
  111. p.computedInContext = make(map[string]map[string]string)
  112. p.computedFactors = make(map[string]string)
  113. rootHash := []byte{}
  114. p.computedInContext[string(rootHash)] = make(map[string]string)
  115. p.r1CSRecursiveBuild(p.getMainCircuit(), p.getMainCircuit().root, rootHash, &orderedmGates, false, false)
  116. return orderedmGates
  117. }
  // r1CSRecursiveBuild recursively walks the parse tree below node and
  // appends every multiplication gate needed for the QAP construction to
  // *orderedmGates. Multiplications with constants and additions
  // (= subtractions) are folded into factor lists instead of producing gates.
  //
  // Parameters:
  //   - currentCircuit: the circuit node belongs to; replaced by the callee's
  //     circuit when node is a FUNC gate.
  //   - node: the gate to process.
  //   - hashTraceBuildup: hash chain identifying the calling context; used as
  //     key into p.computedInContext.
  //   - orderedmGates: output list the emitted gates are appended to.
  //   - negate, invert: sign / inversion flags inherited from the parents.
  //
  // Returns the factors representing node, a hash for the processed subtree
  // and variableEnd, which is true when the result refers to a variable and
  // false for a pure constant.
  //
  // It panics on a CONST gate whose name is not a value and on gates that are
  // neither MULTIPLY nor PLUS once the leaf cases have been handled.
  func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTraceBuildup []byte, orderedmGates *[]gate, negate bool, invert bool) (facs []factor, hashTraceResult []byte, variableEnd bool) {
  	// Constant leaf: a single CONST factor holding the fraction v1/1, or
  	// 1/v1 when inverted. Its hash is ten zero bytes.
  	if node.OperationType() == CONST {
  		b1, v1 := isValue(node.value.Out)
  		if !b1 {
  			panic("not a constant")
  		}
  		mul := [2]int{v1, 1}
  		if invert {
  			mul = [2]int{1, v1}
  		}
  		return []factor{{typ: CONST, negate: negate, multiplicative: mul}}, make([]byte, 10), false
  	}
  	// Function call: continue at the root of the called circuit and extend
  	// the hash trace by that circuit's current output name, so the call gets
  	// its own memoisation context.
  	if node.OperationType() == FUNC {
  		nextContext := p.extendedFunctionRenamer(currentCircuit, node.value)
  		currentCircuit = nextContext
  		node = nextContext.root
  		hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(currentCircuit.currentOutputName()))
  		if _, ex := p.computedInContext[string(hashTraceBuildup)]; !ex {
  			p.computedInContext[string(hashTraceBuildup)] = make(map[string]string)
  		}
  	}
  	// Input leaf: a single IN factor named after the input.
  	if node.OperationType() == IN {
  		fac := factor{typ: IN, name: node.value.Out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}
  		hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(node.value.Out))
  		return []factor{fac}, hashTraceBuildup, true
  	}
  	// Already computed in this context: refer to the stored variable name
  	// instead of walking the subtree again.
  	if out, ex := p.computedInContext[string(hashTraceBuildup)][node.value.Out]; ex {
  		fac := factor{typ: IN, name: out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}
  		hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(node.value.Out))
  		return []factor{fac}, hashTraceBuildup, true
  	}
  	// The left operand inherits the flags unchanged; for the right operand
  	// they are XOR-ed with the flags stored on this node's constraint.
  	leftFactors, leftHash, variableEnd := p.r1CSRecursiveBuild(currentCircuit, node.left, hashTraceBuildup, orderedmGates, negate, invert)
  	rightFactors, rightHash, cons := p.r1CSRecursiveBuild(currentCircuit, node.right, hashTraceBuildup, orderedmGates, Xor(negate, node.value.negate), Xor(invert, node.value.invert))
  	if node.OperationType() == MULTIPLY {
  		// Unless both operands end in a variable, the product can be folded
  		// into the factor list without emitting a gate. Inverted gates and
  		// the root of main always produce a gate.
  		if !(variableEnd && cons) && !node.value.invert && node != p.getMainCircuit().root {
  			return mulFactors(leftFactors, rightFactors), append(leftHash, rightHash...), variableEnd || cons
  		}
  		// Reuse an earlier gate that was built from the very same factors.
  		sig := factorsSignature(leftFactors, rightFactors)
  		if out, ex := p.computedFactors[sig]; ex {
  			return []factor{{typ: IN, name: out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}}, hashTraceBuildup, true
  		}
  		rootGate := cloneGate(node)
  		rootGate.index = len(*orderedmGates)
  		rootGate.leftIns = leftFactors
  		rootGate.rightIns = rightFactors
  		// Operand and output names are made unique by appending the first
  		// ten bytes of the corresponding hash.
  		out := hashTogether(leftHash, rightHash)
  		rootGate.value.V1 = rootGate.value.V1 + string(leftHash[:10])
  		rootGate.value.V2 = rootGate.value.V2 + string(rightHash[:10])
  		//note we only check for existence, but not for truth.
  		// Global outputs keep their plain name.
  		if _, ex := p.globalOutput[rootGate.value.Out]; !ex {
  			rootGate.value.Out = rootGate.value.Out + string(out[:10])
  		}
  		// Remember the gate both per context and per factor signature.
  		p.computedInContext[string(hashTraceBuildup)][node.value.Out] = rootGate.value.Out
  		p.computedFactors[sig] = rootGate.value.Out
  		*orderedmGates = append(*orderedmGates, *rootGate)
  		hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(rootGate.value.Out))
  		return []factor{{typ: IN, name: rootGate.value.Out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}}, hashTraceBuildup, true
  	}
  	switch node.OperationType() {
  	case PLUS:
  		// Additions never emit a gate; the operand factor lists are merged.
  		return addFactors(leftFactors, rightFactors), hashTogether(leftHash, rightHash), variableEnd || cons
  	default:
  		panic("unexpected gate")
  	}
  }
  187. type factor struct {
  188. typ Token
  189. name string
  190. invert, negate bool
  191. multiplicative [2]int
  192. }
  193. func (f factor) String() string {
  194. if f.typ == CONST {
  195. return fmt.Sprintf("(const fac: %v)", f.multiplicative)
  196. }
  197. str := f.name
  198. if f.invert {
  199. str += "^-1"
  200. }
  201. if f.negate {
  202. str = "-" + str
  203. }
  204. return fmt.Sprintf("(\"%s\" fac: %v)", str, f.multiplicative)
  205. }
  // mul2DVector multiplies a and b component-wise. For fractions stored as
  // {numerator, denominator} this is the product of the two fractions.
  func mul2DVector(a, b [2]int) [2]int {
  	var product [2]int
  	for i := range product {
  		product[i] = a[i] * b[i]
  	}
  	return product
  }
  209. func factorsSignature(leftFactors, rightFactors []factor) string {
  210. hasher.Reset()
  211. //using a commutative operation here would be better. since a * b = b * a, but H(a,b) != H(b,a)
  212. //could use (g^a)^b == (g^b)^a where g is a generator of some prime field where the dicrete log is known to be hard
  213. for _, facLeft := range leftFactors {
  214. hasher.Write([]byte(facLeft.String()))
  215. }
  216. for _, Righ := range rightFactors {
  217. hasher.Write([]byte(Righ.String()))
  218. }
  219. return string(hasher.Sum(nil))[:16]
  220. }
  221. //multiplies factor elements and returns the result
  222. //in case the factors do not hold any constants and all inputs are distinct, the output will be the concatenation of left+right
  223. func mulFactors(leftFactors, rightFactors []factor) (result []factor) {
  224. for _, facLeft := range leftFactors {
  225. for i, facRight := range rightFactors {
  226. if facLeft.typ == CONST && facRight.typ == IN {
  227. rightFactors[i] = factor{typ: IN, name: facRight.name, negate: Xor(facLeft.negate, facRight.negate), invert: facRight.invert, multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  228. continue
  229. }
  230. if facRight.typ == CONST && facLeft.typ == IN {
  231. rightFactors[i] = factor{typ: IN, name: facLeft.name, negate: Xor(facLeft.negate, facRight.negate), invert: facLeft.invert, multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  232. continue
  233. }
  234. if facRight.typ&facLeft.typ == CONST {
  235. rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  236. continue
  237. }
  238. //tricky part here
  239. //this one should only be reached, after a true mgate had its left and right braches computed. here we
  240. //a factor can appear at most in quadratic form. we reduce terms a*a^-1 here.
  241. if facRight.typ&facLeft.typ == IN {
  242. if facLeft.name == facRight.name {
  243. if facRight.invert != facLeft.invert {
  244. rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  245. continue
  246. }
  247. }
  248. //rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  249. //continue
  250. }
  251. panic("unexpected. If this errror is thrown, its probably brcause a true multiplication gate has been skipped and treated as on with constant multiplication or addition ")
  252. }
  253. }
  254. return rightFactors
  255. }
  // abs returns the absolute value of n together with a flag that is true
  // when n is zero or positive. For the smallest representable int the
  // negation wraps around and the input value itself is returned.
  func abs(n int) (val int, positive bool) {
  	if n < 0 {
  		return -n, false
  	}
  	return n, true
  }
  262. //returns the reduced sum of two input factor arrays
  263. //if no reduction was done (worst case), it returns the concatenation of the input arrays
  264. func addFactors(leftFactors, rightFactors []factor) []factor {
  265. var found bool
  266. res := make([]factor, 0, len(leftFactors)+len(rightFactors))
  267. for _, facLeft := range leftFactors {
  268. found = false
  269. for i, facRight := range rightFactors {
  270. if facLeft.typ&facRight.typ == CONST {
  271. var a0, b0 = facLeft.multiplicative[0], facRight.multiplicative[0]
  272. if facLeft.negate {
  273. a0 *= -1
  274. }
  275. if facRight.negate {
  276. b0 *= -1
  277. }
  278. absValue, positive := abs(a0*facRight.multiplicative[1] + facLeft.multiplicative[1]*b0)
  279. rightFactors[i] = factor{typ: CONST, negate: !positive, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}}
  280. found = true
  281. //res = append(res, factor{typ: CONST, negate: negate, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}})
  282. break
  283. }
  284. if facLeft.typ&facRight.typ == IN && facLeft.invert == facRight.invert && facLeft.name == facRight.name {
  285. var a0, b0 = facLeft.multiplicative[0], facRight.multiplicative[0]
  286. if facLeft.negate {
  287. a0 *= -1
  288. }
  289. if facRight.negate {
  290. b0 *= -1
  291. }
  292. absValue, positive := abs(a0*facRight.multiplicative[1] + facLeft.multiplicative[1]*b0)
  293. rightFactors[i] = factor{typ: IN, invert: facRight.invert, name: facRight.name, negate: !positive, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}}
  294. found = true
  295. //res = append(res, factor{typ: CONST, negate: negate, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}})
  296. break
  297. }
  298. }
  299. if !found {
  300. res = append(res, facLeft)
  301. }
  302. }
  303. for _, val := range rightFactors {
  304. if val.multiplicative[0] != 0 {
  305. res = append(res, val)
  306. }
  307. }
  308. return res
  309. }
  310. //copies a gate neglecting its references to other gates
  311. func cloneGate(in *gate) (out *gate) {
  312. constr := &Constraint{Inputs: in.value.Inputs, Out: in.value.Out, Op: in.value.Op, invert: in.value.invert, negate: in.value.negate, V2: in.value.V2, V1: in.value.V1}
  313. nRightins := make([]factor, len(in.rightIns))
  314. nLeftInst := make([]factor, len(in.leftIns))
  315. for k, v := range in.rightIns {
  316. nRightins[k] = v
  317. }
  318. for k, v := range in.leftIns {
  319. nLeftInst[k] = v
  320. }
  321. return &gate{value: constr, leftIns: nLeftInst, rightIns: nRightins, index: in.index}
  322. }
  323. func (p *Program) getMainCircuit() *Circuit {
  324. return p.functions["main"]
  325. }
  326. func prepareUtils() utils {
  327. bn, err := bn128.NewBn128()
  328. if err != nil {
  329. panic(err)
  330. }
  331. // new Finite Field
  332. fqR := fields.NewFq(bn.R)
  333. // new Polynomial Field
  334. pf := r1csqap.NewPolynomialField(fqR)
  335. return utils{
  336. Bn: bn,
  337. FqR: fqR,
  338. PF: pf,
  339. }
  340. }
  341. func (p *Program) extendedFunctionRenamer(contextCircuit *Circuit, constraint *Constraint) (nextContext *Circuit) {
  342. if constraint.Op != FUNC {
  343. panic("not a function")
  344. }
  345. //if _, ex := contextCircuit.gateMap[constraint.Out]; !ex {
  346. // panic("constraint must be within the contextCircuit circuit")
  347. //}
  348. b, n, _ := isFunction(constraint.Out)
  349. if !b {
  350. panic("not expected")
  351. }
  352. if newContext, v := p.functions[n]; v {
  353. //am i certain that constraint.inputs is alwazs equal to n??? me dont like it
  354. for i, argument := range constraint.Inputs {
  355. isConst, _ := isValue(argument)
  356. if isConst {
  357. continue
  358. }
  359. isFunc, _, _ := isFunction(argument)
  360. if isFunc {
  361. panic("functions as arguments no supported yet")
  362. //p.extendedFunctionRenamer(contextCircuit,)
  363. }
  364. //at this point I assert that argument is a variable. This can become troublesome later
  365. //first we get the circuit in which the argument was created
  366. inputOriginCircuit := p.functions[getContextFromVariable(argument)]
  367. //we pick the gate that has the argument as output
  368. if gate, ex := inputOriginCircuit.gateMap[argument]; ex {
  369. //we pick the old circuit inputs and let them now reference the same as the argument gate did,
  370. oldGate := newContext.gateMap[newContext.Inputs[i]]
  371. //we take the old gate which was nothing but a input
  372. //and link this input to its constituents coming from the calling contextCircuit.
  373. //i think this is pretty neat
  374. oldGate.value = gate.value
  375. oldGate.right = gate.right
  376. oldGate.left = gate.left
  377. } else {
  378. panic("not expected")
  379. }
  380. }
  381. //newContext.renameInputs(constraint.Inputs)
  382. return newContext
  383. }
  384. return nil
  385. }
  386. func NewProgram() (p *Program) {
  387. p = &Program{
  388. functions: map[string]*Circuit{},
  389. globalInputs: []string{"one"},
  390. globalOutput: map[string]bool{"main": true},
  391. arithmeticEnvironment: prepareUtils(),
  392. }
  393. return
  394. }
  395. // GenerateR1CS generates the R1CS polynomials from the Circuit
  396. func (p *Program) GenerateReducedR1CS(mGates []gate) (r1CS R1CS) {
  397. // from flat code to R1CS
  398. offset := len(p.globalInputs) - len(p.globalOutput)
  399. // one + in1 +in2+... + gate1 + gate2 .. + out
  400. size := offset + len(mGates)
  401. indexMap := make(map[string]int)
  402. for i, v := range p.globalInputs {
  403. indexMap[v] = i
  404. }
  405. for _, v := range mGates {
  406. if _, ex := indexMap[v.value.Out]; !ex {
  407. indexMap[v.value.Out] = len(indexMap)
  408. }
  409. }
  410. for _, g := range mGates {
  411. if g.OperationType() == MULTIPLY {
  412. aConstraint := r1csqap.ArrayOfBigZeros(size)
  413. bConstraint := r1csqap.ArrayOfBigZeros(size)
  414. cConstraint := r1csqap.ArrayOfBigZeros(size)
  415. insertValue := func(val factor, arr []*big.Int) {
  416. if val.typ != CONST {
  417. if _, ex := indexMap[val.name]; !ex {
  418. panic(fmt.Sprintf("%v index not found!!!", val.name))
  419. }
  420. }
  421. value := new(big.Int).Add(new(big.Int), fractionToField(val.multiplicative))
  422. if val.negate {
  423. value.Neg(value)
  424. }
  425. //not that index is 0 if its a constant, since 0 is the map default if no entry was found
  426. arr[indexMap[val.name]] = value
  427. }
  428. for _, val := range g.leftIns {
  429. insertValue(val, aConstraint)
  430. }
  431. for _, val := range g.rightIns {
  432. insertValue(val, bConstraint)
  433. }
  434. cConstraint[indexMap[g.value.Out]] = big.NewInt(int64(1))
  435. if g.value.invert {
  436. tmp := aConstraint
  437. aConstraint = cConstraint
  438. cConstraint = tmp
  439. }
  440. r1CS.A = append(r1CS.A, aConstraint)
  441. r1CS.B = append(r1CS.B, bConstraint)
  442. r1CS.C = append(r1CS.C, cConstraint)
  443. } else {
  444. panic("not a m gate")
  445. }
  446. }
  447. return
  448. }
  449. var Utils = prepareUtils()
  450. func fractionToField(in [2]int) *big.Int {
  451. return Utils.FqR.Mul(big.NewInt(int64(in[0])), Utils.FqR.Inverse(big.NewInt(int64(in[1]))))
  452. }
  453. //Calculates the witness (program trace) given some input
  454. //asserts that R1CS has been computed and is stored in the program p memory calling this function
  455. func CalculateWitness(input []*big.Int, r1cs R1CS) (witness []*big.Int) {
  456. witness = r1csqap.ArrayOfBigZeros(len(r1cs.A[0]))
  457. set := make([]bool, len(witness))
  458. witness[0] = big.NewInt(int64(1))
  459. set[0] = true
  460. for i := range input {
  461. witness[i+1] = input[i]
  462. set[i+1] = true
  463. }
  464. zero := big.NewInt(int64(0))
  465. for i := 0; i < len(r1cs.A); i++ {
  466. gatesLeftInputs := r1cs.A[i]
  467. gatesRightInputs := r1cs.B[i]
  468. gatesOutputs := r1cs.C[i]
  469. sumLeft := big.NewInt(int64(0))
  470. sumRight := big.NewInt(int64(0))
  471. sumOut := big.NewInt(int64(0))
  472. index := -1
  473. division := false
  474. for j, val := range gatesLeftInputs {
  475. if val.Cmp(zero) != 0 {
  476. if !set[j] {
  477. index = j
  478. division = true
  479. break
  480. }
  481. sumLeft.Add(sumLeft, new(big.Int).Mul(val, witness[j]))
  482. }
  483. }
  484. for j, val := range gatesRightInputs {
  485. if val.Cmp(zero) != 0 {
  486. sumRight.Add(sumRight, new(big.Int).Mul(val, witness[j]))
  487. }
  488. }
  489. for j, val := range gatesOutputs {
  490. if val.Cmp(zero) != 0 {
  491. if !set[j] {
  492. if index != -1 {
  493. panic("invalid R1CS form")
  494. }
  495. index = j
  496. break
  497. }
  498. sumOut.Add(sumOut, new(big.Int).Mul(val, witness[j]))
  499. }
  500. }
  501. if !division {
  502. set[index] = true
  503. witness[index] = new(big.Int).Mul(sumLeft, sumRight)
  504. } else {
  505. b := sumRight.Int64()
  506. c := sumOut.Int64()
  507. set[index] = true
  508. //TODO replace with proper multiplication of b^-1 within the finite field
  509. witness[index] = big.NewInt(c / b)
  510. //Utils.FqR.Mul(sumOut, Utils.FqR.Inverse(sumRight))
  511. }
  512. }
  513. return
  514. }
  // hasher is a package-level SHA-256 instance. It is shared mutable state
  // and must not be used from several goroutines at once.
  var hasher = sha256.New()

  // hashTogether returns the SHA-256 digest (32 bytes) of a followed by b.
  // It uses its own hash instance, so it neither disturbs nor depends on
  // the state of the shared package-level hasher.
  func hashTogether(a, b []byte) []byte {
  	h := sha256.New()
  	h.Write(a)
  	h.Write(b)
  	return h.Sum(nil)
  }