You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

612 lines
18 KiB

  1. package circuitcompiler
  2. import (
  3. "crypto/sha256"
  4. "fmt"
  5. "github.com/arnaucube/go-snark/bn128"
  6. "github.com/arnaucube/go-snark/fields"
  7. "github.com/arnaucube/go-snark/r1csqap"
  8. "hash"
  9. "math/big"
  10. "sync"
  11. )
// utils bundles the arithmetic back-ends used by the compiler, as built by
// prepareUtils: the BN128 curve, a finite field constructed from the curve's
// R value, and a polynomial field over that finite field.
type utils struct {
	// Bn is the BN128 curve instance returned by bn128.NewBn128.
	Bn bn128.Bn128
	// FqR is the finite field fields.NewFq(Bn.R).
	FqR fields.Fq
	// PF is the polynomial field r1csqap.NewPolynomialField(FqR).
	PF r1csqap.PolynomialField
}
  17. type R1CS struct {
  18. A [][]*big.Int
  19. B [][]*big.Int
  20. C [][]*big.Int
  21. }
// Program holds every parsed function of a circuit program together with the
// bookkeeping needed to reduce it to a list of multiplication gates and an R1CS.
type Program struct {
	// functions maps a function name to its parsed circuit; the entry point is
	// stored under "main" (see getMainCircuit).
	functions map[string]*Circuit
	// globalInputs lists the witness variables that precede the gate outputs.
	// NewProgram seeds it with "one"; BuildConstraintTrees appends main's inputs.
	globalInputs []string
	// arithmeticEnvironment holds the curve / field helpers built by
	// prepareUtils. TODO: find a better name.
	arithmeticEnvironment utils
	// sha256Hasher is initialised by NewProgram; it is not referenced elsewhere
	// in this file.
	sha256Hasher hash.Hash
	// computedInContext memoises variables that were already turned into gates,
	// so that a variable is not computed a second time.
	// Outer key: the hash chain describing from where the variable is reached,
	// e.g. H(H(main(a,b)), doSomething(x,z)) where H is a hash function.
	// Inner map: variable name -> name of the gate output that holds it
	// (variable name + hash chain suffix).
	// This data structure works, but may be replaced later with something less
	// confusing.
	computedInContext map[string]map[string]string
	// computedFactors reduces the number of multiplication gates: it maps the
	// signature of a gate's left/right factors (see factorsSignature) to the
	// gate output name, so a later gate with the very same factors reuses the
	// former gate instead of creating a new one.
	computedFactors map[string]string
}
  38. func (p *Program) GlobalInputCount() int {
  39. return len(p.globalInputs)
  40. }
  41. func (p *Program) PrintContraintTrees() {
  42. for k, v := range p.functions {
  43. fmt.Println(k)
  44. PrintTree(v.root)
  45. }
  46. }
  47. func (p *Program) BuildConstraintTrees() {
  48. mainRoot := p.getMainCircuit().root
  49. //if our programs last operation is not a multiplication gate, we need to introduce on
  50. if mainRoot.value.Op&(MINUS|PLUS) != 0 {
  51. newOut := Constraint{Out: "out", V1: "1", V2: "out2", Op: MULTIPLY}
  52. p.getMainCircuit().addConstraint(&newOut)
  53. mainRoot.value.Out = "main@out2"
  54. p.getMainCircuit().gateMap[mainRoot.value.Out] = mainRoot
  55. }
  56. for _, in := range p.getMainCircuit().Inputs {
  57. p.globalInputs = append(p.globalInputs, in)
  58. }
  59. var wg = sync.WaitGroup{}
  60. //we build the parse trees concurrently! because we can! go rocks
  61. for _, circuit := range p.functions {
  62. wg.Add(1)
  63. //interesting: if circuit is not passed as argument, the program fails. duno why..
  64. go func(c *Circuit) {
  65. c.buildTree(c.root)
  66. wg.Done()
  67. }(circuit)
  68. }
  69. wg.Wait()
  70. return
  71. }
  72. func (c *Circuit) buildTree(g *gate) {
  73. if _, ex := c.gateMap[g.value.Out]; ex {
  74. if g.OperationType()&(IN|CONST) != 0 {
  75. return
  76. }
  77. } else {
  78. panic(fmt.Sprintf("undefined variable %s", g.value.Out))
  79. }
  80. if g.OperationType() == FUNC {
  81. for _, in := range g.value.Inputs {
  82. if gate, ex := c.gateMap[in]; ex {
  83. g.funcInputs = append(g.funcInputs, gate)
  84. c.buildTree(gate)
  85. } else {
  86. panic(fmt.Sprintf("undefined argument %s", g.value.V1))
  87. }
  88. }
  89. return
  90. }
  91. if constr, ex := c.gateMap[g.value.V1]; ex {
  92. g.left = constr
  93. c.buildTree(g.left)
  94. } else {
  95. panic(fmt.Sprintf("undefined value %s", g.value.V1))
  96. }
  97. if constr, ex := c.gateMap[g.value.V2]; ex {
  98. g.right = constr
  99. c.buildTree(g.right)
  100. } else {
  101. panic(fmt.Sprintf("undefined value %s", g.value.V2))
  102. }
  103. }
  104. func (p *Program) ReduceCombinedTree() (orderedmGates []gate) {
  105. orderedmGates = []gate{}
  106. p.computedInContext = make(map[string]map[string]string)
  107. p.computedFactors = make(map[string]string)
  108. rootHash := []byte{}
  109. p.computedInContext[string(rootHash)] = make(map[string]string)
  110. p.r1CSRecursiveBuild(p.getMainCircuit(), p.getMainCircuit().root, rootHash, &orderedmGates, false, false)
  111. return orderedmGates
  112. }
// r1CSRecursiveBuild recursively walks the parse tree rooted at node and
// appends to *orderedmGates every multiplication gate needed for the QAP
// construction. Multiplications by constants and additions (which also cover
// subtractions) are not turned into gates. Instead they are folded into the
// factor lists (linear combinations) that are returned.
//
// Parameters:
//   - currentCircuit: the function circuit node belongs to; it is replaced by
//     the callee's circuit when node is a function call.
//   - hashTraceBuildup: hash chain identifying the calling context; it keys
//     p.computedInContext.
//   - orderedmGates: output list of created multiplication gates, in order.
//   - negate, invert: sign / inversion applied to node's value. The right
//     operand additionally receives node.value.negate and node.value.invert
//     via Xor (presumably how a-b and a/b are encoded — confirm with parser).
//
// Returns:
//   - facs: the linear combination representing node's value.
//   - hashTraceResult: hash trace of the sub-expression; the first 10 bytes
//     of it are used to suffix the names of created gates.
//   - variableEnd: whether facs contains a variable (false for constants).
//
// Panics on a CONST node whose output is not a value, and on any other node
// type than CONST, FUNC, IN, MULTIPLY and PLUS.
func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTraceBuildup []byte, orderedmGates *[]gate, negate bool, invert bool) (facs []factor, hashTraceResult []byte, variableEnd bool) {
	// Constants become a single CONST factor; an inverted constant v is stored
	// as the fraction [1, v]. The returned hash is 10 zero bytes, which keeps
	// the [:10] slices further down in range.
	if node.OperationType() == CONST {
		b1, v1 := isValue(node.value.Out)
		if !b1 {
			panic("not a constant")
		}
		mul := [2]int{v1, 1}
		if invert {
			mul = [2]int{1, v1}
		}
		return []factor{{typ: CONST, negate: negate, multiplicative: mul}}, make([]byte, 10), false
	}
	// Function call: continue in the callee circuit (whose inputs
	// extendedFunctionRenamer binds to the arguments) and extend the context
	// hash chain with currentCircuit.currentOutputName().
	// NOTE(review): extendedFunctionRenamer returns nil for an unknown
	// function, which would make nextContext.root panic here.
	if node.OperationType() == FUNC {
		nextContext := p.extendedFunctionRenamer(currentCircuit, node.value)
		currentCircuit = nextContext
		node = nextContext.root
		hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(currentCircuit.currentOutputName()))
		if _, ex := p.computedInContext[string(hashTraceBuildup)]; !ex {
			p.computedInContext[string(hashTraceBuildup)] = make(map[string]string)
		}
	}
	// Input variable: a leaf with coefficient 1.
	if node.OperationType() == IN {
		fac := factor{typ: IN, name: node.value.Out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}
		hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(node.value.Out))
		return []factor{fac}, hashTraceBuildup, true
	}
	// Already turned into a gate in this context: reference that gate's output.
	if out, ex := p.computedInContext[string(hashTraceBuildup)][node.value.Out]; ex {
		fac := factor{typ: IN, name: out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}
		hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(node.value.Out))
		return []factor{fac}, hashTraceBuildup, true
	}
	// variableEnd (the named result) receives the left operand's flag here;
	// cons is the right operand's flag.
	leftFactors, leftHash, variableEnd := p.r1CSRecursiveBuild(currentCircuit, node.left, hashTraceBuildup, orderedmGates, negate, invert)
	rightFactors, rightHash, cons := p.r1CSRecursiveBuild(currentCircuit, node.right, hashTraceBuildup, orderedmGates, Xor(negate, node.value.negate), Xor(invert, node.value.invert))
	if node.OperationType() == MULTIPLY {
		// At least one operand is variable-free, the node is not a division
		// and not main's root: the product stays a linear combination, so no
		// gate is emitted.
		if !(variableEnd && cons) && !node.value.invert && node != p.getMainCircuit().root {
			//if !(variableEnd && cons) && !node.value.invert && node != p.getMainCircuit().root {
			return mulFactors(leftFactors, rightFactors), append(leftHash, rightHash...), variableEnd || cons
		}
		// Reuse an existing gate with exactly the same left/right factors.
		// NOTE(review): this path returns hashTraceBuildup unchanged and does
		// not record node in computedInContext, unlike the path below.
		sig := factorsSignature(leftFactors, rightFactors)
		if out, ex := p.computedFactors[sig]; ex {
			return []factor{{typ: IN, name: out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}}, hashTraceBuildup, true
		}
		// Emit a new multiplication gate. Its names are suffixed with the
		// first 10 (raw) bytes of the operand hashes so they are unique per
		// context.
		rootGate := cloneGate(node)
		rootGate.index = len(*orderedmGates)
		rootGate.leftIns = leftFactors
		rootGate.rightIns = rightFactors
		out := hashTogether(leftHash, rightHash)
		rootGate.value.V1 = rootGate.value.V1 + string(leftHash[:10])
		rootGate.value.V2 = rootGate.value.V2 + string(rightHash[:10])
		rootGate.value.Out = rootGate.value.Out + string(out[:10])
		p.computedInContext[string(hashTraceBuildup)][node.value.Out] = rootGate.value.Out
		p.computedFactors[sig] = rootGate.value.Out
		*orderedmGates = append(*orderedmGates, *rootGate)
		hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(rootGate.value.Out))
		// The caller sees the gate output as a single variable.
		return []factor{{typ: IN, name: rootGate.value.Out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}}, hashTraceBuildup, true
	}
	switch node.OperationType() {
	case PLUS:
		return addFactors(leftFactors, rightFactors), hashTogether(leftHash, rightHash), variableEnd || cons
	default:
		panic("unexpected gate")
	}
}
  179. type factor struct {
  180. typ Token
  181. name string
  182. invert, negate bool
  183. multiplicative [2]int
  184. }
  185. func (f factor) String() string {
  186. if f.typ == CONST {
  187. return fmt.Sprintf("(const fac: %v)", f.multiplicative)
  188. }
  189. str := f.name
  190. if f.invert {
  191. str += "^-1"
  192. }
  193. if f.negate {
  194. str = "-" + str
  195. }
  196. return fmt.Sprintf("(\"%s\" fac: %v)", str, f.multiplicative)
  197. }
  198. func mul2DVector(a, b [2]int) [2]int {
  199. return [2]int{a[0] * b[0], a[1] * b[1]}
  200. }
  201. func factorsSignature(leftFactors, rightFactors []factor) string {
  202. hasher.Reset()
  203. //not using a kommutative operation here would be better. since a * b = b * a, but H(a,b) != H(b,a)
  204. //could use (g^a)^b == (g^b)^a where g is a generator of some prime field where the dicrete log is known to be hard
  205. for _, facLeft := range leftFactors {
  206. hasher.Write([]byte(facLeft.String()))
  207. }
  208. for _, Righ := range rightFactors {
  209. hasher.Write([]byte(Righ.String()))
  210. }
  211. return string(hasher.Sum(nil))[:16]
  212. }
  213. func mulFactors(leftFactors, rightFactors []factor) (result []factor) {
  214. for _, facLeft := range leftFactors {
  215. for i, facRight := range rightFactors {
  216. if facLeft.typ == CONST && facRight.typ == IN {
  217. rightFactors[i] = factor{typ: IN, name: facRight.name, negate: Xor(facLeft.negate, facRight.negate), invert: facRight.invert, multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  218. continue
  219. }
  220. if facRight.typ == CONST && facLeft.typ == IN {
  221. rightFactors[i] = factor{typ: IN, name: facLeft.name, negate: Xor(facLeft.negate, facRight.negate), invert: facLeft.invert, multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  222. continue
  223. }
  224. if facRight.typ&facLeft.typ == CONST {
  225. rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  226. continue
  227. }
  228. //tricky part here
  229. //this one should only be reached, after a true mgate had its left and right braches computed. here we
  230. //a factor can appear at most in quadratic form. we reduce terms a*a^-1 here.
  231. if facRight.typ&facLeft.typ == IN {
  232. if facLeft.name == facRight.name {
  233. if facRight.invert != facLeft.invert {
  234. rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  235. continue
  236. }
  237. }
  238. //rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  239. //continue
  240. }
  241. panic("unexpected")
  242. }
  243. }
  244. return rightFactors
  245. }
  246. //returns the absolute value of a signed int and a flag telling if the input was positive or not
  247. //this implementation is awesome and fast (see Henry S Warren, Hackers's Delight)
  248. func abs(n int) (val int, positive bool) {
  249. y := n >> 63
  250. return (n ^ y) - y, y == 0
  251. }
  252. //returns the reduced sum of two input factor arrays
  253. //if no reduction was done (worst case), it returns the concatenation of the input arrays
  254. func addFactors(leftFactors, rightFactors []factor) []factor {
  255. var found bool
  256. res := make([]factor, 0, len(leftFactors)+len(rightFactors))
  257. for _, facLeft := range leftFactors {
  258. found = false
  259. for i, facRight := range rightFactors {
  260. if facLeft.typ&facRight.typ == CONST {
  261. var a0, b0 = facLeft.multiplicative[0], facRight.multiplicative[0]
  262. if facLeft.negate {
  263. a0 *= -1
  264. }
  265. if facRight.negate {
  266. b0 *= -1
  267. }
  268. absValue, positive := abs(a0*facRight.multiplicative[1] + facLeft.multiplicative[1]*b0)
  269. rightFactors[i] = factor{typ: CONST, negate: !positive, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}}
  270. found = true
  271. //res = append(res, factor{typ: CONST, negate: negate, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}})
  272. break
  273. }
  274. if facLeft.typ&facRight.typ == IN && facLeft.invert == facRight.invert && facLeft.name == facRight.name {
  275. var a0, b0 = facLeft.multiplicative[0], facRight.multiplicative[0]
  276. if facLeft.negate {
  277. a0 *= -1
  278. }
  279. if facRight.negate {
  280. b0 *= -1
  281. }
  282. absValue, positive := abs(a0*facRight.multiplicative[1] + facLeft.multiplicative[1]*b0)
  283. rightFactors[i] = factor{typ: IN, invert: facRight.invert, name: facRight.name, negate: !positive, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}}
  284. found = true
  285. //res = append(res, factor{typ: CONST, negate: negate, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}})
  286. break
  287. }
  288. }
  289. if !found {
  290. res = append(res, facLeft)
  291. }
  292. }
  293. for _, val := range rightFactors {
  294. if val.multiplicative[0] != 0 {
  295. res = append(res, val)
  296. }
  297. }
  298. return res
  299. }
  300. //copies a gate neglecting its references to other gates
  301. func cloneGate(in *gate) (out *gate) {
  302. constr := &Constraint{Inputs: in.value.Inputs, Out: in.value.Out, Op: in.value.Op, invert: in.value.invert, negate: in.value.negate, V2: in.value.V2, V1: in.value.V1}
  303. nRightins := make([]factor, len(in.rightIns))
  304. nLeftInst := make([]factor, len(in.leftIns))
  305. for k, v := range in.rightIns {
  306. nRightins[k] = v
  307. }
  308. for k, v := range in.leftIns {
  309. nLeftInst[k] = v
  310. }
  311. return &gate{value: constr, leftIns: nLeftInst, rightIns: nRightins, index: in.index}
  312. }
  313. func (p *Program) getMainCircuit() *Circuit {
  314. return p.functions["main"]
  315. }
  316. func prepareUtils() utils {
  317. bn, err := bn128.NewBn128()
  318. if err != nil {
  319. panic(err)
  320. }
  321. // new Finite Field
  322. fqR := fields.NewFq(bn.R)
  323. // new Polynomial Field
  324. pf := r1csqap.NewPolynomialField(fqR)
  325. return utils{
  326. Bn: bn,
  327. FqR: fqR,
  328. PF: pf,
  329. }
  330. }
  331. func (p *Program) extendedFunctionRenamer(contextCircuit *Circuit, constraint *Constraint) (nextContext *Circuit) {
  332. if constraint.Op != FUNC {
  333. panic("not a function")
  334. }
  335. //if _, ex := contextCircuit.gateMap[constraint.Out]; !ex {
  336. // panic("constraint must be within the contextCircuit circuit")
  337. //}
  338. b, n, _ := isFunction(constraint.Out)
  339. if !b {
  340. panic("not expected")
  341. }
  342. if newContext, v := p.functions[n]; v {
  343. //am i certain that constraint.inputs is alwazs equal to n??? me dont like it
  344. for i, argument := range constraint.Inputs {
  345. isConst, _ := isValue(argument)
  346. if isConst {
  347. continue
  348. }
  349. isFunc, _, _ := isFunction(argument)
  350. if isFunc {
  351. panic("functions as arguments no supported yet")
  352. //p.extendedFunctionRenamer(contextCircuit,)
  353. }
  354. //at this point I assert that argument is a variable. This can become troublesome later
  355. //first we get the circuit in which the argument was created
  356. inputOriginCircuit := p.functions[getContextFromVariable(argument)]
  357. //we pick the gate that has the argument as output
  358. if gate, ex := inputOriginCircuit.gateMap[argument]; ex {
  359. //we pick the old circuit inputs and let them now reference the same as the argument gate did,
  360. oldGate := newContext.gateMap[newContext.Inputs[i]]
  361. //we take the old gate which was nothing but a input
  362. //and link this input to its constituents coming from the calling contextCircuit.
  363. //i think this is pretty neat
  364. oldGate.value = gate.value
  365. oldGate.right = gate.right
  366. oldGate.left = gate.left
  367. } else {
  368. panic("not expected")
  369. }
  370. }
  371. //newContext.renameInputs(constraint.Inputs)
  372. return newContext
  373. }
  374. return nil
  375. }
  376. func NewProgram() (p *Program) {
  377. p = &Program{
  378. functions: map[string]*Circuit{},
  379. globalInputs: []string{"one"},
  380. arithmeticEnvironment: prepareUtils(),
  381. sha256Hasher: sha256.New(),
  382. }
  383. return
  384. }
  385. // GenerateR1CS generates the R1CS polynomials from the Circuit
  386. func (p *Program) GenerateReducedR1CS(mGates []gate) (r1CS R1CS) {
  387. // from flat code to R1CS
  388. offset := len(p.globalInputs)
  389. // one + in1 +in2+... + gate1 + gate2 .. + out
  390. size := offset + len(mGates)
  391. indexMap := make(map[string]int)
  392. for i, v := range p.globalInputs {
  393. indexMap[v] = i
  394. }
  395. for i, v := range mGates {
  396. indexMap[v.value.Out] = i + offset
  397. }
  398. for _, g := range mGates {
  399. if g.OperationType() == MULTIPLY {
  400. aConstraint := r1csqap.ArrayOfBigZeros(size)
  401. bConstraint := r1csqap.ArrayOfBigZeros(size)
  402. cConstraint := r1csqap.ArrayOfBigZeros(size)
  403. insertValue := func(val factor, arr []*big.Int) {
  404. if val.typ != CONST {
  405. if _, ex := indexMap[val.name]; !ex {
  406. panic(fmt.Sprintf("%v index not found!!!", val.name))
  407. }
  408. }
  409. value := new(big.Int).Add(new(big.Int), fractionToField(val.multiplicative))
  410. if val.negate {
  411. value.Neg(value)
  412. }
  413. //not that index is 0 if its a constant, since 0 is the map default if no entry was found
  414. arr[indexMap[val.name]] = value
  415. }
  416. for _, val := range g.leftIns {
  417. insertValue(val, aConstraint)
  418. }
  419. for _, val := range g.rightIns {
  420. insertValue(val, bConstraint)
  421. }
  422. cConstraint[indexMap[g.value.Out]] = big.NewInt(int64(1))
  423. if g.value.invert {
  424. tmp := aConstraint
  425. aConstraint = cConstraint
  426. cConstraint = tmp
  427. }
  428. r1CS.A = append(r1CS.A, aConstraint)
  429. r1CS.B = append(r1CS.B, bConstraint)
  430. r1CS.C = append(r1CS.C, cConstraint)
  431. } else {
  432. panic("not a m gate")
  433. }
  434. }
  435. return
  436. }
  437. var Utils = prepareUtils()
  438. func fractionToField(in [2]int) *big.Int {
  439. return Utils.FqR.Mul(big.NewInt(int64(in[0])), Utils.FqR.Inverse(big.NewInt(int64(in[1]))))
  440. }
  441. //Calculates the witness (program trace) given some input
  442. //asserts that R1CS has been computed and is stored in the program p memory calling this function
  443. func CalculateWitness(input []*big.Int, r1cs R1CS) (witness []*big.Int) {
  444. //if len(p.globalInputs)-1 != len(input) {
  445. // panic("input do not match the required inputs")
  446. //}
  447. witness = r1csqap.ArrayOfBigZeros(len(r1cs.A[0]))
  448. set := make([]bool, len(witness))
  449. witness[0] = big.NewInt(int64(1))
  450. set[0] = true
  451. for i := range input {
  452. witness[i+1] = input[i]
  453. set[i+1] = true
  454. }
  455. zero := big.NewInt(int64(0))
  456. for i := 0; i < len(r1cs.A); i++ {
  457. gatesLeftInputs := r1cs.A[i]
  458. gatesRightInputs := r1cs.B[i]
  459. gatesOutputs := r1cs.C[i]
  460. sumLeft := big.NewInt(int64(0))
  461. sumRight := big.NewInt(int64(0))
  462. sumOut := big.NewInt(int64(0))
  463. index := -1
  464. division := false
  465. for j, val := range gatesLeftInputs {
  466. if val.Cmp(zero) != 0 {
  467. if !set[j] {
  468. index = j
  469. division = true
  470. break
  471. }
  472. sumLeft.Add(sumLeft, new(big.Int).Mul(val, witness[j]))
  473. }
  474. }
  475. for j, val := range gatesRightInputs {
  476. if val.Cmp(zero) != 0 {
  477. sumRight.Add(sumRight, new(big.Int).Mul(val, witness[j]))
  478. }
  479. }
  480. for j, val := range gatesOutputs {
  481. if val.Cmp(zero) != 0 {
  482. if !set[j] {
  483. if index != -1 {
  484. panic("invalid R1CS form")
  485. }
  486. index = j
  487. break
  488. }
  489. sumOut.Add(sumOut, new(big.Int).Mul(val, witness[j]))
  490. }
  491. }
  492. if !division {
  493. set[index] = true
  494. witness[index] = new(big.Int).Mul(sumLeft, sumRight)
  495. } else {
  496. b := sumRight.Int64()
  497. c := sumOut.Int64()
  498. set[index] = true
  499. witness[index] = big.NewInt(c / b)
  500. }
  501. }
  502. return
  503. }
  504. var hasher = sha256.New()
  505. func hashTogether(a, b []byte) []byte {
  506. hasher.Reset()
  507. hasher.Write(a)
  508. hasher.Write(b)
  509. return hasher.Sum(nil)
  510. }