You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

597 lines
17 KiB

  1. package circuitcompiler
  2. import (
  3. "crypto/sha256"
  4. "fmt"
  5. "github.com/mottla/go-snark/bn128"
  6. "github.com/mottla/go-snark/fields"
  7. "github.com/mottla/go-snark/r1csqap"
  8. "hash"
  9. "math/big"
  10. "sync"
  11. )
  12. type utils struct {
  13. Bn bn128.Bn128
  14. FqR fields.Fq
  15. PF r1csqap.PolynomialField
  16. }
  17. type Program struct {
  18. functions map[string]*Circuit
  19. globalInputs []string
  20. arithmeticEnvironment utils //find a better name
  21. sha256Hasher hash.Hash
  22. computedInContext map[string]map[string]string
  23. R1CS struct {
  24. A [][]*big.Int
  25. B [][]*big.Int
  26. C [][]*big.Int
  27. }
  28. }
  29. func (p *Program) PrintContraintTrees() {
  30. for k, v := range p.functions {
  31. fmt.Println(k)
  32. PrintTree(v.root)
  33. }
  34. }
  35. func (p *Program) BuildConstraintTrees() {
  36. mainRoot := p.getMainCircuit().root
  37. if mainRoot.value.Op&(MINUS|PLUS) != 0 {
  38. newOut := Constraint{Out: "out", V1: "1", V2: "out2", Op: MULTIPLY}
  39. p.getMainCircuit().addConstraint(&newOut)
  40. mainRoot.value.Out = "main@out2"
  41. p.getMainCircuit().gateMap[mainRoot.value.Out] = mainRoot
  42. }
  43. //for _, in := range p.getMainCircuit().Inputs {
  44. // p.globalInputs = append(p.globalInputs, composeNewFunction(in, p.getMainCircuit().Inputs))
  45. //}
  46. for _, in := range p.getMainCircuit().Inputs {
  47. p.globalInputs = append(p.globalInputs, in)
  48. }
  49. var wg = sync.WaitGroup{}
  50. for _, circuit := range p.functions {
  51. wg.Add(1)
  52. func() {
  53. circuit.buildTree(circuit.root)
  54. wg.Done()
  55. }()
  56. }
  57. wg.Wait()
  58. return
  59. }
  60. func (c *Circuit) buildTree(g *gate) {
  61. if _, ex := c.gateMap[g.value.Out]; ex {
  62. if g.OperationType()&(IN|CONST) != 0 {
  63. return
  64. }
  65. } else {
  66. panic(fmt.Sprintf("undefined variable %s", g.value.Out))
  67. }
  68. if g.OperationType() == FUNC {
  69. //g.funcInputs = []*gate{}
  70. for _, in := range g.value.Inputs {
  71. if gate, ex := c.gateMap[in]; ex {
  72. g.funcInputs = append(g.funcInputs, gate)
  73. //note that we do repeated work here. the argument
  74. c.buildTree(gate)
  75. } else {
  76. panic(fmt.Sprintf("undefined argument %s", g.value.V1))
  77. }
  78. }
  79. return
  80. }
  81. if constr, ex := c.gateMap[g.value.V1]; ex {
  82. g.left = constr
  83. c.buildTree(g.left)
  84. } else {
  85. panic(fmt.Sprintf("undefined value %s", g.value.V1))
  86. }
  87. if constr, ex := c.gateMap[g.value.V2]; ex {
  88. g.right = constr
  89. c.buildTree(g.right)
  90. } else {
  91. panic(fmt.Sprintf("undefined value %s", g.value.V2))
  92. }
  93. }
  94. func (p *Program) ReduceCombinedTree() (orderedmGates []gate) {
  95. //mGatesUsed := make(map[string]bool)
  96. orderedmGates = []gate{}
  97. p.computedInContext = make(map[string]map[string]string)
  98. rootHash := []byte{}
  99. p.computedInContext[string(rootHash)] = make(map[string]string)
  100. p.r1CSRecursiveBuild(p.getMainCircuit(), p.getMainCircuit().root, rootHash, &orderedmGates, false, false)
  101. return orderedmGates
  102. }
  // r1CSRecursiveBuild reduces the gate tree rooted at node (inside
  // currentCircuit) to a sum of factors: constants and variable references with
  // fractional coefficients. Additions, and products with a variable-free
  // operand, are folded into the factor list. Every other product is emitted as
  // a multiplication gate appended to *orderedmGates, and a reference to its
  // output is returned instead.
  //
  // negate / invert ask for the negation / reciprocal of node's value. Both are
  // passed down to the operands; the right operand additionally XORs in
  // node.value.negate / node.value.invert.
  //
  // Results:
  //   facs            - node's value as a list of summed factors
  //   hashTraceResult - a hash of the sub-tree; its first 10 bytes are appended
  //                     to emitted gate names to make them unique
  //   variableEnd     - false when the value involves only constants
  //
  // hashTraceBuildup identifies the call context. Entering a function call
  // extends it with the callee's output name, and p.computedInContext maps each
  // context to the renamed outputs of the gates already emitted in it.
  //
  // NOTE(review): on the folding path for MULTIPLY, `negate` reaches both
  // operands, so the two negations cancel in mulFactors and the product is
  // returned un-negated (a negated 3*x folds to +3x). The emitting path is not
  // affected because it re-applies `negate` to the returned factor. Confirm
  // against how the parser encodes subtraction.
  // NOTE(review): with invert set, a CONST leaf stores the reciprocal in its
  // coefficient. GenerateReducedR1CS, however, ignores the invert flag of
  // factors and swaps rows a/c for division gates, so x/3 appears to become
  // out = 3x. Confirm.
  103. func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTraceBuildup []byte, orderedmGates *[]gate, negate bool, invert bool) (facs []factor, hashTraceResult []byte, variableEnd bool) {
  // Constant leaf: coefficient v/1, or 1/v when inverted. All constants share
  // the same 10 zero-byte hash.
  104. if node.OperationType() == CONST {
  105. b1, v1 := isValue(node.value.Out)
  106. if !b1 {
  107. panic("not a constant")
  108. }
  109. mul := [2]int{v1, 1}
  110. if invert {
  111. mul = [2]int{1, v1}
  112. }
  113. return []factor{{typ: CONST, negate: negate, multiplicative: mul}}, make([]byte, 10), false
  114. }
  // Function call: continue in the callee circuit, whose inputs
  // extendedFunctionRenamer links to the call's arguments, under an extended
  // call-context hash with a fresh memo table.
  // NOTE(review): extendedFunctionRenamer returns nil for an unknown function,
  // which would panic on nextContext.root below.
  115. if node.OperationType() == FUNC {
  116. nextContext := p.extendedFunctionRenamer(currentCircuit, node.value)
  117. currentCircuit = nextContext
  118. node = nextContext.root
  119. hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(currentCircuit.currentOutputName()))
  120. if _, ex := p.computedInContext[string(hashTraceBuildup)]; !ex {
  121. p.computedInContext[string(hashTraceBuildup)] = make(map[string]string)
  122. }
  123. }
  // Variable leaf.
  124. if node.OperationType() == IN {
  125. fac := factor{typ: IN, name: node.value.Out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}
  126. hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(node.value.Out))
  127. return []factor{fac}, hashTraceBuildup, true
  128. }
  // Gate already emitted in this context: reference its renamed output.
  129. if out, ex := p.computedInContext[string(hashTraceBuildup)][node.value.Out]; ex {
  130. fac := factor{typ: IN, name: out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}
  131. hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(node.value.Out))
  132. return []factor{fac}, hashTraceBuildup, true
  133. }
  // Reduce both operands, left first. Gates emitted while doing so are
  // appended to *orderedmGates in that order.
  134. leftFactors, leftHash, variableEnd := p.r1CSRecursiveBuild(currentCircuit, node.left, hashTraceBuildup, orderedmGates, negate, invert)
  135. rightFactors, rightHash, cons := p.r1CSRecursiveBuild(currentCircuit, node.right, hashTraceBuildup, orderedmGates, Xor(negate, node.value.negate), Xor(invert, node.value.invert))
  136. if node.OperationType() == MULTIPLY {
  // Fold the product unless both operands contain variables, the gate has
  // value.invert set (presumably a division), or it is main's root gate.
  137. if !(variableEnd && cons) && !node.value.invert && node != p.getMainCircuit().root {
  138. //if !(variableEnd && cons) && !node.value.invert && node != p.getMainCircuit().root {
  139. return mulFactors(leftFactors, rightFactors), append(leftHash, rightHash...), variableEnd || cons
  140. }
  // Emit a multiplication gate. Its V1/V2/Out names get the first 10 raw bytes
  // of the operand/result hashes appended (every path above returns at least
  // 10 bytes), and the new output name is memoized for this context.
  141. rootGate := cloneGate(node)
  142. rootGate.index = len(*orderedmGates)
  143. rootGate.leftIns = leftFactors
  144. rootGate.rightIns = rightFactors
  145. out := hashTogether(leftHash, rightHash)
  146. rootGate.value.V1 = rootGate.value.V1 + string(leftHash[:10])
  147. rootGate.value.V2 = rootGate.value.V2 + string(rightHash[:10])
  148. rootGate.value.Out = rootGate.value.Out + string(out[:10])
  149. p.computedInContext[string(hashTraceBuildup)][node.value.Out] = rootGate.value.Out
  150. *orderedmGates = append(*orderedmGates, *rootGate)
  151. hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(rootGate.value.Out))
  152. return []factor{{typ: IN, name: rootGate.value.Out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}}, hashTraceBuildup, true
  153. }
  // Addition merges like terms; any other operation is unsupported here.
  154. switch node.OperationType() {
  155. case PLUS:
  156. return addFactors(leftFactors, rightFactors), hashTogether(leftHash, rightHash), variableEnd || cons
  157. default:
  158. panic("unexpected gate")
  159. }
  160. //TODO optimize if output is not a multiplication gate
  161. }
  162. type factor struct {
  163. typ Token
  164. name string
  165. invert, negate bool
  166. multiplicative [2]int
  167. }
  168. func (f factor) String() string {
  169. if f.typ == CONST {
  170. return fmt.Sprintf("(const fac: %v)", f.multiplicative)
  171. }
  172. str := f.name
  173. if f.invert {
  174. str += "^-1"
  175. }
  176. if f.negate {
  177. str = "-" + str
  178. }
  179. return fmt.Sprintf("(\"%s\" fac: %v)", str, f.multiplicative)
  180. }
  // mul2DVector multiplies two {numerator, denominator} pairs element-wise. In
  // other words, it multiplies the fractions they represent without reducing
  // the result.
  func mul2DVector(a, b [2]int) [2]int {
  	var product [2]int
  	for i := range product {
  		product[i] = a[i] * b[i]
  	}
  	return product
  }
  184. func mulFactors(leftFactors, rightFactors []factor) (result []factor) {
  185. for _, facLeft := range leftFactors {
  186. for i, facRight := range rightFactors {
  187. if facLeft.typ == CONST && facRight.typ == IN {
  188. rightFactors[i] = factor{typ: IN, name: facRight.name, negate: Xor(facLeft.negate, facRight.negate), invert: facRight.invert, multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  189. continue
  190. }
  191. if facRight.typ == CONST && facLeft.typ == IN {
  192. rightFactors[i] = factor{typ: IN, name: facLeft.name, negate: Xor(facLeft.negate, facRight.negate), invert: facLeft.invert, multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  193. continue
  194. }
  195. if facRight.typ&facLeft.typ == CONST {
  196. rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  197. continue
  198. }
  199. //tricky part here
  200. //this one should only be reached, after a true mgate had its left and right braches computed. here we
  201. //a factor can appear at most in quadratic form. we reduce terms a*a^-1 here.
  202. if facRight.typ&facLeft.typ == IN {
  203. if facLeft.name == facRight.name {
  204. if facRight.invert != facLeft.invert {
  205. rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  206. continue
  207. }
  208. }
  209. //rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
  210. //continue
  211. }
  212. fmt.Println("dsf")
  213. panic("unexpected")
  214. }
  215. }
  216. return rightFactors
  217. }
  // abs returns |n| together with a flag that is true when n is non-negative
  // (zero counts as positive). As with two's-complement negation,
  // abs(math.MinInt) overflows and returns math.MinInt with positive == false.
  func abs(n int) (val int, positive bool) {
  	if n < 0 {
  		return -n, false
  	}
  	return n, true
  }
  224. //returns the reduced sum of two input factor arrays
  225. //if no reduction was done (worst case), it returns the concatenation of the input arrays
  226. func addFactors(leftFactors, rightFactors []factor) []factor {
  227. var found bool
  228. res := make([]factor, 0, len(leftFactors)+len(rightFactors))
  229. for _, facLeft := range leftFactors {
  230. found = false
  231. for i, facRight := range rightFactors {
  232. if facLeft.typ&facRight.typ == CONST {
  233. var a0, b0 = facLeft.multiplicative[0], facRight.multiplicative[0]
  234. if facLeft.negate {
  235. a0 *= -1
  236. }
  237. if facRight.negate {
  238. b0 *= -1
  239. }
  240. absValue, positive := abs(a0*facRight.multiplicative[1] + facLeft.multiplicative[1]*b0)
  241. rightFactors[i] = factor{typ: CONST, negate: !positive, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}}
  242. found = true
  243. //res = append(res, factor{typ: CONST, negate: negate, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}})
  244. break
  245. }
  246. if facLeft.typ&facRight.typ == IN && facLeft.invert == facRight.invert && facLeft.name == facRight.name {
  247. var a0, b0 = facLeft.multiplicative[0], facRight.multiplicative[0]
  248. if facLeft.negate {
  249. a0 *= -1
  250. }
  251. if facRight.negate {
  252. b0 *= -1
  253. }
  254. absValue, positive := abs(a0*facRight.multiplicative[1] + facLeft.multiplicative[1]*b0)
  255. rightFactors[i] = factor{typ: IN, invert: facRight.invert, name: facRight.name, negate: !positive, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}}
  256. found = true
  257. //res = append(res, factor{typ: CONST, negate: negate, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}})
  258. break
  259. }
  260. }
  261. if !found {
  262. res = append(res, facLeft)
  263. }
  264. }
  265. for _, val := range rightFactors {
  266. if val.multiplicative[0] != 0 {
  267. res = append(res, val)
  268. }
  269. }
  270. return res
  271. }
  272. //copies a gate neglecting its references to other gates
  273. func cloneGate(in *gate) (out *gate) {
  274. constr := &Constraint{Inputs: in.value.Inputs, Out: in.value.Out, Op: in.value.Op, invert: in.value.invert, negate: in.value.negate, V2: in.value.V2, V1: in.value.V1}
  275. nRightins := make([]factor, len(in.rightIns))
  276. nLeftInst := make([]factor, len(in.leftIns))
  277. for k, v := range in.rightIns {
  278. nRightins[k] = v
  279. }
  280. for k, v := range in.leftIns {
  281. nLeftInst[k] = v
  282. }
  283. return &gate{value: constr, leftIns: nLeftInst, rightIns: nRightins, index: in.index}
  284. }
  285. func (p *Program) getMainCircuit() *Circuit {
  286. return p.functions["main"]
  287. }
  288. //func (p *Program) addGlobalInput(c Constraint) {
  289. // c.Out = "main@" + c.Out
  290. // p.globalInputs = append(p.globalInputs, c)
  291. //}
  292. func prepareUtils() utils {
  293. bn, err := bn128.NewBn128()
  294. if err != nil {
  295. panic(err)
  296. }
  297. // new Finite Field
  298. fqR := fields.NewFq(bn.R)
  299. // new Polynomial Field
  300. pf := r1csqap.NewPolynomialField(fqR)
  301. return utils{
  302. Bn: bn,
  303. FqR: fqR,
  304. PF: pf,
  305. }
  306. }
  307. func (p *Program) extendedFunctionRenamer(contextCircuit *Circuit, constraint *Constraint) (nextContext *Circuit) {
  308. if constraint.Op != FUNC {
  309. panic("not a function")
  310. }
  311. //if _, ex := contextCircuit.gateMap[constraint.Out]; !ex {
  312. // panic("constraint must be within the contextCircuit circuit")
  313. //}
  314. b, n, _ := isFunction(constraint.Out)
  315. if !b {
  316. panic("not expected")
  317. }
  318. if newContext, v := p.functions[n]; v {
  319. //am i certain that constraint.inputs is alwazs equal to n??? me dont like it
  320. for i, argument := range constraint.Inputs {
  321. isConst, _ := isValue(argument)
  322. if isConst {
  323. continue
  324. }
  325. isFunc, _, _ := isFunction(argument)
  326. if isFunc {
  327. panic("functions as arguments no supported yet")
  328. //p.extendedFunctionRenamer(contextCircuit,)
  329. }
  330. //at this point I assert that argument is a variable. This can become troublesome later
  331. //first we get the circuit in which the argument was created
  332. inputOriginCircuit := p.functions[getContextFromVariable(argument)]
  333. //we pick the gate that has the argument as output
  334. if gate, ex := inputOriginCircuit.gateMap[argument]; ex {
  335. //we pick the old circuit inputs and let them now reference the same as the argument gate did,
  336. oldGate := newContext.gateMap[newContext.Inputs[i]]
  337. //we take the old gate which was nothing but a input
  338. //and link this input to its constituents coming from the calling contextCircuit.
  339. //i think this is pretty neat
  340. oldGate.value = gate.value
  341. oldGate.right = gate.right
  342. oldGate.left = gate.left
  343. } else {
  344. panic("not expected")
  345. }
  346. }
  347. //newContext.renameInputs(constraint.Inputs)
  348. return newContext
  349. }
  350. return nil
  351. }
  352. func NewProgram() (p *Program) {
  353. p = &Program{
  354. functions: map[string]*Circuit{},
  355. globalInputs: []string{"one"},
  356. arithmeticEnvironment: prepareUtils(),
  357. sha256Hasher: sha256.New(),
  358. }
  359. return
  360. }
  361. // GenerateR1CS generates the R1CS polynomials from the Circuit
  362. func (p *Program) GenerateReducedR1CS(mGates []gate) (a, b, c [][]*big.Int) {
  363. // from flat code to R1CS
  364. offset := len(p.globalInputs)
  365. // one + in1 +in2+... + gate1 + gate2 .. + out
  366. size := offset + len(mGates)
  367. indexMap := make(map[string]int)
  368. for i, v := range p.globalInputs {
  369. indexMap[v] = i
  370. }
  371. for i, v := range mGates {
  372. indexMap[v.value.Out] = i + offset
  373. }
  374. for _, gate := range mGates {
  375. if gate.OperationType() == MULTIPLY {
  376. aConstraint := r1csqap.ArrayOfBigZeros(size)
  377. bConstraint := r1csqap.ArrayOfBigZeros(size)
  378. cConstraint := r1csqap.ArrayOfBigZeros(size)
  379. for _, val := range gate.leftIns {
  380. if val.typ != CONST {
  381. if _, ex := indexMap[val.name]; !ex {
  382. panic(fmt.Sprintf("%v index not found!!!", val.name))
  383. }
  384. }
  385. convertAndInsertFactorAt(aConstraint, val, indexMap[val.name])
  386. }
  387. for _, val := range gate.rightIns {
  388. if val.typ != CONST {
  389. if _, ex := indexMap[val.name]; !ex {
  390. panic(fmt.Sprintf("%v index not found!!!", val.name))
  391. }
  392. }
  393. convertAndInsertFactorAt(bConstraint, val, indexMap[val.name])
  394. }
  395. cConstraint[indexMap[gate.value.Out]] = big.NewInt(int64(1))
  396. if gate.value.invert {
  397. tmp := aConstraint
  398. aConstraint = cConstraint
  399. cConstraint = tmp
  400. }
  401. a = append(a, aConstraint)
  402. b = append(b, bConstraint)
  403. c = append(c, cConstraint)
  404. } else {
  405. panic("not a m gate")
  406. }
  407. }
  408. p.R1CS.A = a
  409. p.R1CS.B = b
  410. p.R1CS.C = c
  411. return a, b, c
  412. }
  413. var Utils = prepareUtils()
  414. func fractionToField(in [2]int) *big.Int {
  415. return Utils.FqR.Mul(big.NewInt(int64(in[0])), Utils.FqR.Inverse(big.NewInt(int64(in[1]))))
  416. }
  417. func convertAndInsertFactorAt(arr []*big.Int, val factor, index int) {
  418. value := new(big.Int).Add(new(big.Int), fractionToField(val.multiplicative))
  419. if val.negate {
  420. value.Neg(value)
  421. }
  422. //not that index is 0 if its a constant, since 0 is the map default if no entry was found
  423. arr[index] = value
  424. }
  425. func (p *Program) CalculateWitness(input []*big.Int) (witness []*big.Int) {
  426. if len(p.globalInputs)-1 != len(input) {
  427. panic("input do not match the required inputs")
  428. }
  429. witness = r1csqap.ArrayOfBigZeros(len(p.R1CS.A[0]))
  430. set := make([]bool, len(witness))
  431. witness[0] = big.NewInt(int64(1))
  432. set[0] = true
  433. for i := range input {
  434. witness[i+1] = input[i]
  435. set[i+1] = true
  436. }
  437. zero := big.NewInt(int64(0))
  438. for i := 0; i < len(p.R1CS.A); i++ {
  439. gatesLeftInputs := p.R1CS.A[i]
  440. gatesRightInputs := p.R1CS.B[i]
  441. gatesOutputs := p.R1CS.C[i]
  442. sumLeft := big.NewInt(int64(0))
  443. sumRight := big.NewInt(int64(0))
  444. sumOut := big.NewInt(int64(0))
  445. index := -1
  446. division := false
  447. for j, val := range gatesLeftInputs {
  448. if val.Cmp(zero) != 0 {
  449. if !set[j] {
  450. index = j
  451. division = true
  452. break
  453. }
  454. sumLeft.Add(sumLeft, new(big.Int).Mul(val, witness[j]))
  455. }
  456. }
  457. for j, val := range gatesRightInputs {
  458. if val.Cmp(zero) != 0 {
  459. sumRight.Add(sumRight, new(big.Int).Mul(val, witness[j]))
  460. }
  461. }
  462. for j, val := range gatesOutputs {
  463. if val.Cmp(zero) != 0 {
  464. if !set[j] {
  465. if index != -1 {
  466. panic("invalid R1CS form")
  467. }
  468. index = j
  469. break
  470. }
  471. sumOut.Add(sumOut, new(big.Int).Mul(val, witness[j]))
  472. }
  473. }
  474. if !division {
  475. set[index] = true
  476. witness[index] = new(big.Int).Mul(sumLeft, sumRight)
  477. } else {
  478. b := sumRight.Int64()
  479. c := sumOut.Int64()
  480. set[index] = true
  481. witness[index] = big.NewInt(c / b)
  482. }
  483. }
  484. return
  485. }
  var (
  	// hasher is a single SHA-256 state shared at package level. hash.Hash is not
  	// documented as safe for concurrent use, so code that Resets and writes it
  	// must not run on several goroutines at once.
  	hasher = sha256.New()
  )
  487. func hashFactorWithContext(f factor, currentCircuit *Circuit) []byte {
  488. hasher.Reset()
  489. hasher.Write([]byte(f.name))
  490. hasher.Write([]byte(currentCircuit.currentOutputName()))
  491. return hasher.Sum(nil)
  492. }
  // hashTogether returns SHA-256(a || b): the digest of a followed by b.
  // It uses a fresh hash state per call, so the result never depends on, or
  // disturbs, any shared hasher, and the function is safe for concurrent use.
  // Neither input is modified.
  func hashTogether(a, b []byte) []byte {
  	h := sha256.New()
  	// hash.Hash.Write never returns an error.
  	h.Write(a)
  	h.Write(b)
  	return h.Sum(nil)
  }