You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

632 lines
19 KiB

  1. package circuitcompiler
  2. import (
  3. "crypto/sha256"
  4. "fmt"
  5. "github.com/arnaucube/go-snark/bn128"
  6. "github.com/arnaucube/go-snark/fields"
  7. "github.com/arnaucube/go-snark/r1csqap"
  8. "math/big"
  9. "sync"
  10. )
  11. type utils struct {
  12. Bn bn128.Bn128
  13. FqR fields.Fq
  14. PF r1csqap.PolynomialField
  15. }
  // R1CS is a rank-1 constraint system stored as three matrices. Row i of A, B
// and C together describe the i-th multiplication gate, i.e. the constraint
// (A[i]·w) * (B[i]·w) = (C[i]·w) over a witness vector w.
type R1CS struct {
	A, B, C [][]*big.Int
}
  21. type Program struct {
  22. functions map[string]*Circuit
  23. globalInputs []string
  24. globalOutput map[string]bool
  25. arithmeticEnvironment utils //find a better name
  26. //key 1: the hash chain indicating from where the variable is called H( H(main(a,b)) , doSomething(x,z) ), where H is a hash function.
  27. //value 1 : map
  28. // with key variable name
  29. // with value variable name + hash Chain
  30. //this datastructure is nice but maybe ill replace it later with something less confusing
  31. //it serves the elementary purpose of not computing a variable a second time.
  32. //it boosts parse time
  33. computedInContext map[string]map[string]string
  34. //to reduce the number of multiplication gates, we store each factor signature, and the variable name,
  35. //so each time a variable is computed, that happens to have the very same factors, we reuse the former
  36. //it boost setup and proof time
  37. computedFactors map[string]string
  38. }
  39. //returns the cardinality of all main inputs + 1 for the "one" signal
  40. func (p *Program) GlobalInputCount() int {
  41. return len(p.globalInputs)
  42. }
  43. //returns the cardinaltiy of the output signals. Current only 1 output possible
  44. func (p *Program) GlobalOutputCount() int {
  45. return len(p.globalOutput)
  46. }
  47. func (p *Program) PrintContraintTrees() {
  48. for k, v := range p.functions {
  49. fmt.Println(k)
  50. PrintTree(v.root)
  51. }
  52. }
  53. func (p *Program) BuildConstraintTrees() {
  54. mainRoot := p.getMainCircuit().root
  55. //if our programs last operation is not a multiplication gate, we need to introduce on
  56. if mainRoot.value.Op&(MINUS|PLUS) != 0 {
  57. newOut := Constraint{Out: "out", V1: "1", V2: "out2", Op: MULTIPLY}
  58. p.getMainCircuit().addConstraint(&newOut)
  59. mainRoot.value.Out = "main@out2"
  60. p.getMainCircuit().gateMap[mainRoot.value.Out] = mainRoot
  61. }
  62. for _, in := range p.getMainCircuit().Inputs {
  63. p.globalInputs = append(p.globalInputs, in)
  64. }
  65. var wg = sync.WaitGroup{}
  66. //we build the parse trees concurrently! because we can! go rocks
  67. for _, circuit := range p.functions {
  68. wg.Add(1)
  69. //interesting: if circuit is not passed as argument, the program fails. duno why..
  70. go func(c *Circuit) {
  71. c.buildTree(c.root)
  72. wg.Done()
  73. }(circuit)
  74. }
  75. wg.Wait()
  76. return
  77. }
  78. func (c *Circuit) buildTree(g *gate) {
  79. if _, ex := c.gateMap[g.value.Out]; ex {
  80. if g.OperationType()&(IN|CONST) != 0 {
  81. return
  82. }
  83. } else {
  84. panic(fmt.Sprintf("undefined variable %s", g.value.Out))
  85. }
  86. if g.OperationType() == FUNC {
  87. for _, in := range g.value.Inputs {
  88. if gate, ex := c.gateMap[in]; ex {
  89. g.funcInputs = append(g.funcInputs, gate)
  90. c.buildTree(gate)
  91. } else {
  92. panic(fmt.Sprintf("undefined argument %s", g.value.V1))
  93. }
  94. }
  95. return
  96. }
  97. if constr, ex := c.gateMap[g.value.V1]; ex {
  98. g.left = constr
  99. c.buildTree(g.left)
  100. } else {
  101. panic(fmt.Sprintf("undefined value %s", g.value.V1))
  102. }
  103. if constr, ex := c.gateMap[g.value.V2]; ex {
  104. g.right = constr
  105. c.buildTree(g.right)
  106. } else {
  107. panic(fmt.Sprintf("undefined value %s", g.value.V2))
  108. }
  109. }
  110. func (p *Program) ReduceCombinedTree() (orderedmGates []gate) {
  111. orderedmGates = []gate{}
  112. p.computedInContext = make(map[string]map[string]string)
  113. p.computedFactors = make(map[string]string)
  114. rootHash := []byte{}
  115. p.computedInContext[string(rootHash)] = make(map[string]string)
  116. p.r1CSRecursiveBuild(p.getMainCircuit(), p.getMainCircuit().root, rootHash, &orderedmGates, false, false)
  117. return orderedmGates
  118. }
  119. //recursively walks through the parse tree to create a list of all
  120. //multiplication gates needed for the QAP construction
  121. //Takes into account, that multiplication with constants and addition (= substraction) can be reduced, and does so
  122. func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTraceBuildup []byte, orderedmGates *[]gate, negate bool, invert bool) (facs []factor, hashTraceResult []byte, variableEnd bool) {
  123. if node.OperationType() == CONST {
  124. b1, v1 := isValue(node.value.Out)
  125. if !b1 {
  126. panic("not a constant")
  127. }
  128. mul := [2]int{v1, 1}
  129. if invert {
  130. mul = [2]int{1, v1}
  131. }
  132. return []factor{{typ: CONST, negate: negate, multiplicative: mul}}, make([]byte, 10), false
  133. }
  134. if node.OperationType() == FUNC {
  135. nextContext := p.extendedFunctionRenamer(currentCircuit, node.value)
  136. currentCircuit = nextContext
  137. node = nextContext.root
  138. hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(currentCircuit.currentOutputName()))
  139. if _, ex := p.computedInContext[string(hashTraceBuildup)]; !ex {
  140. p.computedInContext[string(hashTraceBuildup)] = make(map[string]string)
  141. }
  142. }
  143. if node.OperationType() == IN {
  144. fac := factor{typ: IN, name: node.value.Out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}
  145. hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(node.value.Out))
  146. return []factor{fac}, hashTraceBuildup, true
  147. }
  148. if out, ex := p.computedInContext[string(hashTraceBuildup)][node.value.Out]; ex {
  149. fac := factor{typ: IN, name: out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}
  150. hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(node.value.Out))
  151. return []factor{fac}, hashTraceBuildup, true
  152. }
  153. leftFactors, leftHash, variableEnd := p.r1CSRecursiveBuild(currentCircuit, node.left, hashTraceBuildup, orderedmGates, negate, invert)
  154. rightFactors, rightHash, cons := p.r1CSRecursiveBuild(currentCircuit, node.right, hashTraceBuildup, orderedmGates, Xor(negate, node.value.negate), Xor(invert, node.value.invert))
  155. if node.OperationType() == MULTIPLY {
  156. if !(variableEnd && cons) && !node.value.invert && node != p.getMainCircuit().root {
  157. //if !(variableEnd && cons) && !node.value.invert && node != p.getMainCircuit().root {
  158. return mulFactors(leftFactors, rightFactors), append(leftHash, rightHash...), variableEnd || cons
  159. }
  160. sig := factorsSignature(leftFactors, rightFactors)
  161. if out, ex := p.computedFactors[sig]; ex {
  162. return []factor{{typ: IN, name: out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}}, hashTraceBuildup, true
  163. }
  164. rootGate := cloneGate(node)
  165. rootGate.index = len(*orderedmGates)
  166. rootGate.leftIns = leftFactors
  167. rootGate.rightIns = rightFactors
  168. out := hashTogether(leftHash, rightHash)
  169. rootGate.value.V1 = rootGate.value.V1 + string(leftHash[:10])
  170. rootGate.value.V2 = rootGate.value.V2 + string(rightHash[:10])
  171. //note we only check for existence, but not for truth.
  172. //global outputs do not require a hash identifier, since they are unique
  173. if _, ex := p.globalOutput[rootGate.value.Out]; !ex {
  174. rootGate.value.Out = rootGate.value.Out + string(out[:10])
  175. }
  176. p.computedInContext[string(hashTraceBuildup)][node.value.Out] = rootGate.value.Out
  177. p.computedFactors[sig] = rootGate.value.Out
  178. *orderedmGates = append(*orderedmGates, *rootGate)
  179. hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(rootGate.value.Out))
  180. return []factor{{typ: IN, name: rootGate.value.Out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}}, hashTraceBuildup, true
  181. }
  182. switch node.OperationType() {
  183. case PLUS:
  184. return addFactors(leftFactors, rightFactors), hashTogether(leftHash, rightHash), variableEnd || cons
  185. default:
  186. panic("unexpected gate")
  187. }
  188. }
  189. type factor struct {
  190. typ Token
  191. name string
  192. invert, negate bool
  193. multiplicative [2]int
  194. }
  195. func (f factor) String() string {
  196. if f.typ == CONST {
  197. return fmt.Sprintf("(const fac: %v)", f.multiplicative)
  198. }
  199. str := f.name
  200. if f.invert {
  201. str += "^-1"
  202. }
  203. if f.negate {
  204. str = "-" + str
  205. }
  206. return fmt.Sprintf("(\"%s\" fac: %v)", str, f.multiplicative)
  207. }
  // mul2DVector multiplies two [numerator, denominator] pairs component-wise,
// i.e. multiplies two fractions without reducing the result.
func mul2DVector(a, b [2]int) [2]int {
	var out [2]int
	for i := range out {
		out[i] = a[i] * b[i]
	}
	return out
}
  211. func factorsSignature(leftFactors, rightFactors []factor) string {
  212. hasher.Reset()
  213. //using a commutative operation here would be better. since a * b = b * a, but H(a,b) != H(b,a)
  214. //could use (g^a)^b == (g^b)^a where g is a generator of some prime field where the dicrete log is known to be hard
  215. for _, facLeft := range leftFactors {
  216. hasher.Write([]byte(facLeft.String()))
  217. }
  218. for _, Righ := range rightFactors {
  219. hasher.Write([]byte(Righ.String()))
  220. }
  221. return string(hasher.Sum(nil))[:16]
  222. }
  223. //multiplies factor elements and returns the result
  224. //in case the factors do not hold any constants and all inputs are distinct, the output will be the concatenation of left+right
  225. func mulFactors(leftFactors, rightFactors []factor) (result []factor) {
  226. for _, facLeft := range leftFactors {
  227. for i, facRight := range rightFactors {
  228. if facLeft.typ == CONST && facRight.typ == IN {
  229. rightFactors[i] = factor{typ: IN, name: facRight.name, negate: Xor(facLeft.negate, facRight.negate), invert: facRight.invert, multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  230. continue
  231. }
  232. if facRight.typ == CONST && facLeft.typ == IN {
  233. rightFactors[i] = factor{typ: IN, name: facLeft.name, negate: Xor(facLeft.negate, facRight.negate), invert: facLeft.invert, multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  234. continue
  235. }
  236. if facRight.typ&facLeft.typ == CONST {
  237. rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  238. continue
  239. }
  240. //tricky part here
  241. //this one should only be reached, after a true mgate had its left and right braches computed. here we
  242. //a factor can appear at most in quadratic form. we reduce terms a*a^-1 here.
  243. if facRight.typ&facLeft.typ == IN {
  244. if facLeft.name == facRight.name {
  245. if facRight.invert != facLeft.invert {
  246. rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  247. continue
  248. }
  249. }
  250. //rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  251. //continue
  252. }
  253. panic("unexpected. If this errror is thrown, its probably brcause a true multiplication gate has been skipped and treated as on with constant multiplication or addition ")
  254. }
  255. }
  256. return rightFactors
  257. }
  // abs returns the absolute value of n and whether n was non-negative (zero
// counts as positive). Like two's-complement negation, the minimum int has no
// positive counterpart and is returned unchanged with positive == false.
func abs(n int) (val int, positive bool) {
	if n < 0 {
		return -n, false
	}
	return n, true
}
  264. //returns the reduced sum of two input factor arrays
  265. //if no reduction was done (worst case), it returns the concatenation of the input arrays
  266. func addFactors(leftFactors, rightFactors []factor) []factor {
  267. var found bool
  268. res := make([]factor, 0, len(leftFactors)+len(rightFactors))
  269. for _, facLeft := range leftFactors {
  270. found = false
  271. for i, facRight := range rightFactors {
  272. if facLeft.typ&facRight.typ == CONST {
  273. var a0, b0 = facLeft.multiplicative[0], facRight.multiplicative[0]
  274. if facLeft.negate {
  275. a0 *= -1
  276. }
  277. if facRight.negate {
  278. b0 *= -1
  279. }
  280. absValue, positive := abs(a0*facRight.multiplicative[1] + facLeft.multiplicative[1]*b0)
  281. rightFactors[i] = factor{typ: CONST, negate: !positive, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}}
  282. found = true
  283. //res = append(res, factor{typ: CONST, negate: negate, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}})
  284. break
  285. }
  286. if facLeft.typ&facRight.typ == IN && facLeft.invert == facRight.invert && facLeft.name == facRight.name {
  287. var a0, b0 = facLeft.multiplicative[0], facRight.multiplicative[0]
  288. if facLeft.negate {
  289. a0 *= -1
  290. }
  291. if facRight.negate {
  292. b0 *= -1
  293. }
  294. absValue, positive := abs(a0*facRight.multiplicative[1] + facLeft.multiplicative[1]*b0)
  295. rightFactors[i] = factor{typ: IN, invert: facRight.invert, name: facRight.name, negate: !positive, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}}
  296. found = true
  297. //res = append(res, factor{typ: CONST, negate: negate, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}})
  298. break
  299. }
  300. }
  301. if !found {
  302. res = append(res, facLeft)
  303. }
  304. }
  305. for _, val := range rightFactors {
  306. if val.multiplicative[0] != 0 {
  307. res = append(res, val)
  308. }
  309. }
  310. return res
  311. }
  312. //copies a gate neglecting its references to other gates
  313. func cloneGate(in *gate) (out *gate) {
  314. constr := &Constraint{Inputs: in.value.Inputs, Out: in.value.Out, Op: in.value.Op, invert: in.value.invert, negate: in.value.negate, V2: in.value.V2, V1: in.value.V1}
  315. nRightins := make([]factor, len(in.rightIns))
  316. nLeftInst := make([]factor, len(in.leftIns))
  317. for k, v := range in.rightIns {
  318. nRightins[k] = v
  319. }
  320. for k, v := range in.leftIns {
  321. nLeftInst[k] = v
  322. }
  323. return &gate{value: constr, leftIns: nLeftInst, rightIns: nRightins, index: in.index}
  324. }
  325. func (p *Program) getMainCircuit() *Circuit {
  326. return p.functions["main"]
  327. }
  328. func prepareUtils() utils {
  329. bn, err := bn128.NewBn128()
  330. if err != nil {
  331. panic(err)
  332. }
  333. // new Finite Field
  334. fqR := fields.NewFq(bn.R)
  335. // new Polynomial Field
  336. pf := r1csqap.NewPolynomialField(fqR)
  337. return utils{
  338. Bn: bn,
  339. FqR: fqR,
  340. PF: pf,
  341. }
  342. }
  343. func (p *Program) extendedFunctionRenamer(contextCircuit *Circuit, constraint *Constraint) (nextContext *Circuit) {
  344. if constraint.Op != FUNC {
  345. panic("not a function")
  346. }
  347. //if _, ex := contextCircuit.gateMap[constraint.Out]; !ex {
  348. // panic("constraint must be within the contextCircuit circuit")
  349. //}
  350. b, n, _ := isFunction(constraint.Out)
  351. if !b {
  352. panic("not expected")
  353. }
  354. if newContext, v := p.functions[n]; v {
  355. //am i certain that constraint.inputs is alwazs equal to n??? me dont like it
  356. for i, argument := range constraint.Inputs {
  357. isConst, _ := isValue(argument)
  358. if isConst {
  359. continue
  360. }
  361. isFunc, _, _ := isFunction(argument)
  362. if isFunc {
  363. panic("functions as arguments no supported yet")
  364. //p.extendedFunctionRenamer(contextCircuit,)
  365. }
  366. //at this point I assert that argument is a variable. This can become troublesome later
  367. //first we get the circuit in which the argument was created
  368. inputOriginCircuit := p.functions[getContextFromVariable(argument)]
  369. //we pick the gate that has the argument as output
  370. if gate, ex := inputOriginCircuit.gateMap[argument]; ex {
  371. //we pick the old circuit inputs and let them now reference the same as the argument gate did,
  372. oldGate := newContext.gateMap[newContext.Inputs[i]]
  373. //we take the old gate which was nothing but a input
  374. //and link this input to its constituents coming from the calling contextCircuit.
  375. //i think this is pretty neat
  376. oldGate.value = gate.value
  377. oldGate.right = gate.right
  378. oldGate.left = gate.left
  379. } else {
  380. panic("not expected")
  381. }
  382. }
  383. //newContext.renameInputs(constraint.Inputs)
  384. return newContext
  385. }
  386. return nil
  387. }
  388. func NewProgram() (p *Program) {
  389. p = &Program{
  390. functions: map[string]*Circuit{},
  391. globalInputs: []string{"one"},
  392. globalOutput: map[string]bool{"main": true},
  393. arithmeticEnvironment: prepareUtils(),
  394. }
  395. return
  396. }
  397. // GenerateR1CS generates the R1CS polynomials from the Circuit
  398. func (p *Program) GenerateReducedR1CS(mGates []gate) (r1CS R1CS) {
  399. // from flat code to R1CS
  400. offset := len(p.globalInputs)
  401. // one + in1 +in2+... + gate1 + gate2 .. + out
  402. size := offset + len(mGates)
  403. indexMap := make(map[string]int)
  404. for i, v := range p.globalInputs {
  405. indexMap[v] = i
  406. }
  407. for k, _ := range p.globalOutput {
  408. indexMap[k] = len(indexMap)
  409. }
  410. for _, v := range mGates {
  411. if _, ex := indexMap[v.value.Out]; !ex {
  412. indexMap[v.value.Out] = len(indexMap)
  413. }
  414. }
  415. for _, g := range mGates {
  416. if g.OperationType() == MULTIPLY {
  417. aConstraint := r1csqap.ArrayOfBigZeros(size)
  418. bConstraint := r1csqap.ArrayOfBigZeros(size)
  419. cConstraint := r1csqap.ArrayOfBigZeros(size)
  420. insertValue := func(val factor, arr []*big.Int) {
  421. if val.typ != CONST {
  422. if _, ex := indexMap[val.name]; !ex {
  423. panic(fmt.Sprintf("%v index not found!!!", val.name))
  424. }
  425. }
  426. value := new(big.Int).Add(new(big.Int), fractionToField(val.multiplicative))
  427. if val.negate {
  428. value.Neg(value)
  429. }
  430. //not that index is 0 if its a constant, since 0 is the map default if no entry was found
  431. arr[indexMap[val.name]] = value
  432. }
  433. for _, val := range g.leftIns {
  434. insertValue(val, aConstraint)
  435. }
  436. for _, val := range g.rightIns {
  437. insertValue(val, bConstraint)
  438. }
  439. cConstraint[indexMap[g.value.Out]] = big.NewInt(int64(1))
  440. if g.value.invert {
  441. tmp := aConstraint
  442. aConstraint = cConstraint
  443. cConstraint = tmp
  444. }
  445. r1CS.A = append(r1CS.A, aConstraint)
  446. r1CS.B = append(r1CS.B, bConstraint)
  447. r1CS.C = append(r1CS.C, cConstraint)
  448. } else {
  449. panic("not a m gate")
  450. }
  451. }
  452. return
  453. }
  454. var Utils = prepareUtils()
  455. func fractionToField(in [2]int) *big.Int {
  456. return Utils.FqR.Mul(big.NewInt(int64(in[0])), Utils.FqR.Inverse(big.NewInt(int64(in[1]))))
  457. }
  458. //Calculates the witness (program trace) given some input
  459. //asserts that R1CS has been computed and is stored in the program p memory calling this function
  460. func CalculateWitness(input []*big.Int, r1cs R1CS) (witness []*big.Int) {
  461. witness = r1csqap.ArrayOfBigZeros(len(r1cs.A[0]))
  462. set := make([]bool, len(witness))
  463. witness[0] = big.NewInt(int64(1))
  464. set[0] = true
  465. for i := range input {
  466. witness[i+1] = input[i]
  467. set[i+1] = true
  468. }
  469. zero := big.NewInt(int64(0))
  470. for i := 0; i < len(r1cs.A); i++ {
  471. gatesLeftInputs := r1cs.A[i]
  472. gatesRightInputs := r1cs.B[i]
  473. gatesOutputs := r1cs.C[i]
  474. sumLeft := big.NewInt(int64(0))
  475. sumRight := big.NewInt(int64(0))
  476. sumOut := big.NewInt(int64(0))
  477. index := -1
  478. division := false
  479. for j, val := range gatesLeftInputs {
  480. if val.Cmp(zero) != 0 {
  481. if !set[j] {
  482. index = j
  483. division = true
  484. break
  485. }
  486. sumLeft.Add(sumLeft, new(big.Int).Mul(val, witness[j]))
  487. }
  488. }
  489. for j, val := range gatesRightInputs {
  490. if val.Cmp(zero) != 0 {
  491. sumRight.Add(sumRight, new(big.Int).Mul(val, witness[j]))
  492. }
  493. }
  494. for j, val := range gatesOutputs {
  495. if val.Cmp(zero) != 0 {
  496. if !set[j] {
  497. if index != -1 {
  498. panic("invalid R1CS form")
  499. }
  500. index = j
  501. break
  502. }
  503. sumOut.Add(sumOut, new(big.Int).Mul(val, witness[j]))
  504. }
  505. }
  506. if !division {
  507. set[index] = true
  508. witness[index] = new(big.Int).Mul(sumLeft, sumRight)
  509. } else {
  510. b := sumRight.Int64()
  511. c := sumOut.Int64()
  512. set[index] = true
  513. //TODO replace with proper multiplication of b^-1 within the finite field
  514. witness[index] = big.NewInt(c / b)
  515. //Utils.FqR.Mul(sumOut, Utils.FqR.Inverse(sumRight))
  516. }
  517. }
  518. return
  519. }
  // hasher is a package-level SHA-256 state. It is not safe for concurrent
// use, which is why hashTogether does not rely on it.
var hasher = sha256.New()

// hashTogether returns SHA-256(a || b), the digest of a followed by b.
// Each call uses its own hash state, so concurrent calls cannot corrupt each
// other's result (previously the shared hasher was reset and reused).
func hashTogether(a, b []byte) []byte {
	h := sha256.New()
	h.Write(a)
	h.Write(b)
	return h.Sum(nil)
}