You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

487 lines
14 KiB

  1. package circuitcompiler
  2. import (
  3. "crypto/sha256"
  4. "fmt"
  5. "github.com/arnaucube/go-snark/bn128"
  6. "github.com/arnaucube/go-snark/fields"
  7. "github.com/arnaucube/go-snark/r1csqap"
  8. "math/big"
  9. "sync"
  10. )
  11. type utils struct {
  12. Bn bn128.Bn128
  13. FqR fields.Fq
  14. PF r1csqap.PolynomialField
  15. }
  16. type R1CS struct {
  17. A [][]*big.Int
  18. B [][]*big.Int
  19. C [][]*big.Int
  20. }
  21. type MultiplicationGateSignature struct {
  22. identifier string
  23. commonExtracted [2]int //if the mgate had a extractable factor, it will be stored here
  24. }
  25. type Program struct {
  26. functions map[string]*Circuit
  27. globalInputs []string
  28. globalOutput map[string]bool
  29. arithmeticEnvironment utils //find a better name
  30. //key 1: the hash chain indicating from where the variable is called H( H(main(a,b)) , doSomething(x,z) ), where H is a hash function.
  31. //value 1 : map
  32. // with key variable name
  33. // with value variable name + hash Chain
  34. //this datastructure is nice but maybe ill replace it later with something less confusing
  35. //it serves the elementary purpose of not computing a variable a second time.
  36. //it boosts parse time
  37. computedInContext map[string]map[string]MultiplicationGateSignature
  38. //to reduce the number of multiplication gates, we store each factor signature, and the variable name,
  39. //so each time a variable is computed, that happens to have the very same factors, we reuse the former
  40. //it boost setup and proof time
  41. computedFactors map[string]MultiplicationGateSignature
  42. }
  43. //returns the cardinality of all main inputs + 1 for the "one" signal
  44. func (p *Program) GlobalInputCount() int {
  45. return len(p.globalInputs)
  46. }
  47. //returns the cardinaltiy of the output signals. Current only 1 output possible
  48. func (p *Program) GlobalOutputCount() int {
  49. return len(p.globalOutput)
  50. }
  51. func (p *Program) PrintContraintTrees() {
  52. for k, v := range p.functions {
  53. fmt.Println(k)
  54. PrintTree(v.root)
  55. }
  56. }
  57. func (p *Program) BuildConstraintTrees() {
  58. mainRoot := p.getMainCircuit().root
  59. //if our programs last operation is not a multiplication gate, we need to introduce on
  60. if mainRoot.value.Op&(MINUS|PLUS) != 0 {
  61. newOut := Constraint{Out: "out", V1: "1", V2: "out2", Op: MULTIPLY}
  62. p.getMainCircuit().addConstraint(&newOut)
  63. mainRoot.value.Out = "main@out2"
  64. p.getMainCircuit().gateMap[mainRoot.value.Out] = mainRoot
  65. }
  66. for _, in := range p.getMainCircuit().Inputs {
  67. p.globalInputs = append(p.globalInputs, in)
  68. }
  69. var wg = sync.WaitGroup{}
  70. //we build the parse trees concurrently! because we can! go rocks
  71. for _, circuit := range p.functions {
  72. wg.Add(1)
  73. //interesting: if circuit is not passed as argument, the program fails. duno why..
  74. go func(c *Circuit) {
  75. c.buildTree(c.root)
  76. wg.Done()
  77. }(circuit)
  78. }
  79. wg.Wait()
  80. return
  81. }
  82. func (c *Circuit) buildTree(g *gate) {
  83. if _, ex := c.gateMap[g.value.Out]; ex {
  84. if g.OperationType()&(IN|CONST) != 0 {
  85. return
  86. }
  87. } else {
  88. panic(fmt.Sprintf("undefined variable %s", g.value.Out))
  89. }
  90. if g.OperationType() == FUNC {
  91. for _, in := range g.value.Inputs {
  92. if gate, ex := c.gateMap[in]; ex {
  93. g.funcInputs = append(g.funcInputs, gate)
  94. c.buildTree(gate)
  95. } else {
  96. panic(fmt.Sprintf("undefined argument %s", g.value.V1))
  97. }
  98. }
  99. return
  100. }
  101. if constr, ex := c.gateMap[g.value.V1]; ex {
  102. g.left = constr
  103. c.buildTree(g.left)
  104. } else {
  105. panic(fmt.Sprintf("undefined value %s", g.value.V1))
  106. }
  107. if constr, ex := c.gateMap[g.value.V2]; ex {
  108. g.right = constr
  109. c.buildTree(g.right)
  110. } else {
  111. panic(fmt.Sprintf("undefined value %s", g.value.V2))
  112. }
  113. }
  114. func (p *Program) ReduceCombinedTree() (orderedmGates []gate) {
  115. orderedmGates = []gate{}
  116. p.computedInContext = make(map[string]map[string]MultiplicationGateSignature)
  117. p.computedFactors = make(map[string]MultiplicationGateSignature)
  118. rootHash := make([]byte, 10)
  119. p.computedInContext[string(rootHash)] = make(map[string]MultiplicationGateSignature)
  120. p.r1CSRecursiveBuild(p.getMainCircuit(), p.getMainCircuit().root, rootHash, &orderedmGates, false, false)
  121. return orderedmGates
  122. }
  123. //recursively walks through the parse tree to create a list of all
  124. //multiplication gates needed for the QAP construction
  125. //Takes into account, that multiplication with constants and addition (= substraction) can be reduced, and does so
  126. func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTraceBuildup []byte, orderedmGates *[]gate, negate bool, invert bool) (facs factors, hashTraceResult []byte, variableEnd bool) {
  127. if node.OperationType() == CONST {
  128. b1, v1 := isValue(node.value.Out)
  129. if !b1 {
  130. panic("not a constant")
  131. }
  132. mul := [2]int{v1, 1}
  133. if invert {
  134. mul = [2]int{1, v1}
  135. }
  136. return factors{{typ: CONST, negate: negate, multiplicative: mul}}, hashTraceBuildup, false
  137. }
  138. if node.OperationType() == FUNC {
  139. nextContext := p.extendedFunctionRenamer(currentCircuit, node.value)
  140. currentCircuit = nextContext
  141. node = nextContext.root
  142. hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(currentCircuit.currentOutputName()))
  143. if _, ex := p.computedInContext[string(hashTraceBuildup)]; !ex {
  144. p.computedInContext[string(hashTraceBuildup)] = make(map[string]MultiplicationGateSignature)
  145. }
  146. }
  147. if node.OperationType() == IN {
  148. fac := &factor{typ: IN, name: node.value.Out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}
  149. return factors{fac}, hashTraceBuildup, true
  150. }
  151. if out, ex := p.computedInContext[string(hashTraceBuildup)][node.value.Out]; ex {
  152. fac := &factor{typ: IN, name: out.identifier, invert: invert, negate: negate, multiplicative: out.commonExtracted}
  153. return factors{fac}, hashTraceBuildup, true
  154. }
  155. leftFactors, leftHash, variableEnd := p.r1CSRecursiveBuild(currentCircuit, node.left, hashTraceBuildup, orderedmGates, negate, invert)
  156. rightFactors, rightHash, cons := p.r1CSRecursiveBuild(currentCircuit, node.right, hashTraceBuildup, orderedmGates, Xor(negate, node.value.negate), Xor(invert, node.value.invert))
  157. if node.OperationType() == MULTIPLY {
  158. if !(variableEnd && cons) && !node.value.invert && node != p.getMainCircuit().root {
  159. return mulFactors(leftFactors, rightFactors), hashTraceBuildup, variableEnd || cons
  160. }
  161. sig, newLef, newRigh := factorsSignature(leftFactors, rightFactors)
  162. if out, ex := p.computedFactors[sig.identifier]; ex {
  163. return factors{{typ: IN, name: out.identifier, invert: invert, negate: negate, multiplicative: sig.commonExtracted}}, hashTraceBuildup, true
  164. }
  165. rootGate := cloneGate(node)
  166. //rootGate := node
  167. rootGate.index = len(*orderedmGates)
  168. if p.getMainCircuit().root == node {
  169. newLef = mulFactors(newLef, factors{&factor{typ: CONST, multiplicative: sig.commonExtracted}})
  170. }
  171. rootGate.leftIns = newLef
  172. rootGate.rightIns = newRigh
  173. out := hashTogether(leftHash, rightHash)
  174. rootGate.value.V1 = rootGate.value.V1 + string(leftHash[:10])
  175. rootGate.value.V2 = rootGate.value.V2 + string(rightHash[:10])
  176. //note we only check for existence, but not for truth.
  177. //global outputs do not require a hash identifier, since they are unique
  178. if _, ex := p.globalOutput[rootGate.value.Out]; !ex {
  179. rootGate.value.Out = rootGate.value.Out + string(out[:10])
  180. }
  181. p.computedInContext[string(hashTraceBuildup)][node.value.Out] = MultiplicationGateSignature{identifier: rootGate.value.Out, commonExtracted: sig.commonExtracted}
  182. p.computedFactors[sig.identifier] = MultiplicationGateSignature{identifier: rootGate.value.Out, commonExtracted: sig.commonExtracted}
  183. *orderedmGates = append(*orderedmGates, *rootGate)
  184. return factors{{typ: IN, name: rootGate.value.Out, invert: invert, negate: negate, multiplicative: sig.commonExtracted}}, hashTraceBuildup, true
  185. }
  186. switch node.OperationType() {
  187. case PLUS:
  188. return addFactors(leftFactors, rightFactors), hashTraceBuildup, variableEnd || cons
  189. default:
  190. panic("unexpected gate")
  191. }
  192. }
  193. //copies a gate neglecting its references to other gates
  194. func cloneGate(in *gate) (out *gate) {
  195. constr := &Constraint{Inputs: in.value.Inputs, Out: in.value.Out, Op: in.value.Op, invert: in.value.invert, negate: in.value.negate, V2: in.value.V2, V1: in.value.V1}
  196. nRightins := in.rightIns.clone()
  197. nLeftInst := in.leftIns.clone()
  198. return &gate{value: constr, leftIns: nLeftInst, rightIns: nRightins, index: in.index}
  199. }
  200. func (p *Program) getMainCircuit() *Circuit {
  201. return p.functions["main"]
  202. }
  203. func prepareUtils() utils {
  204. bn, err := bn128.NewBn128()
  205. if err != nil {
  206. panic(err)
  207. }
  208. // new Finite Field
  209. fqR := fields.NewFq(bn.R)
  210. // new Polynomial Field
  211. pf := r1csqap.NewPolynomialField(fqR)
  212. return utils{
  213. Bn: bn,
  214. FqR: fqR,
  215. PF: pf,
  216. }
  217. }
  218. func (p *Program) extendedFunctionRenamer(contextCircuit *Circuit, constraint *Constraint) (nextContext *Circuit) {
  219. if constraint.Op != FUNC {
  220. panic("not a function")
  221. }
  222. //if _, ex := contextCircuit.gateMap[constraint.Out]; !ex {
  223. // panic("constraint must be within the contextCircuit circuit")
  224. //}
  225. b, n, _ := isFunction(constraint.Out)
  226. if !b {
  227. panic("not expected")
  228. }
  229. if newContext, v := p.functions[n]; v {
  230. //am i certain that constraint.inputs is alwazs equal to n??? me dont like it
  231. for i, argument := range constraint.Inputs {
  232. isConst, _ := isValue(argument)
  233. if isConst {
  234. continue
  235. }
  236. isFunc, _, _ := isFunction(argument)
  237. if isFunc {
  238. panic("functions as arguments no supported yet")
  239. //p.extendedFunctionRenamer(contextCircuit,)
  240. }
  241. //at this point I assert that argument is a variable. This can become troublesome later
  242. //first we get the circuit in which the argument was created
  243. inputOriginCircuit := p.functions[getContextFromVariable(argument)]
  244. //we pick the gate that has the argument as output
  245. if gate, ex := inputOriginCircuit.gateMap[argument]; ex {
  246. //we pick the old circuit inputs and let them now reference the same as the argument gate did,
  247. oldGate := newContext.gateMap[newContext.Inputs[i]]
  248. //we take the old gate which was nothing but a input
  249. //and link this input to its constituents coming from the calling contextCircuit.
  250. //i think this is pretty neat
  251. oldGate.value = gate.value
  252. oldGate.right = gate.right
  253. oldGate.left = gate.left
  254. } else {
  255. panic("not expected")
  256. }
  257. }
  258. //newContext.renameInputs(constraint.Inputs)
  259. return newContext
  260. }
  261. return nil
  262. }
  263. func NewProgram() (p *Program) {
  264. p = &Program{
  265. functions: map[string]*Circuit{},
  266. globalInputs: []string{"one"},
  267. globalOutput: map[string]bool{"main": true},
  268. arithmeticEnvironment: prepareUtils(),
  269. }
  270. return
  271. }
  272. // GenerateR1CS generates the R1CS polynomials from the Circuit
  273. func (p *Program) GenerateReducedR1CS(mGates []gate) (r1CS R1CS) {
  274. // from flat code to R1CS
  275. offset := len(p.globalInputs)
  276. // one + in1 +in2+... + gate1 + gate2 .. + out
  277. size := offset + len(mGates)
  278. indexMap := make(map[string]int)
  279. for i, v := range p.globalInputs {
  280. indexMap[v] = i
  281. }
  282. for k, _ := range p.globalOutput {
  283. indexMap[k] = len(indexMap)
  284. }
  285. for _, v := range mGates {
  286. if _, ex := indexMap[v.value.Out]; !ex {
  287. indexMap[v.value.Out] = len(indexMap)
  288. }
  289. }
  290. for _, g := range mGates {
  291. if g.OperationType() == MULTIPLY {
  292. aConstraint := r1csqap.ArrayOfBigZeros(size)
  293. bConstraint := r1csqap.ArrayOfBigZeros(size)
  294. cConstraint := r1csqap.ArrayOfBigZeros(size)
  295. insertValue := func(val *factor, arr []*big.Int) {
  296. if val.typ != CONST {
  297. if _, ex := indexMap[val.name]; !ex {
  298. panic(fmt.Sprintf("%v index not found!!!", val.name))
  299. }
  300. }
  301. value := new(big.Int).Add(new(big.Int), fractionToField(val.multiplicative))
  302. if val.negate {
  303. value.Neg(value)
  304. }
  305. //not that index is 0 if its a constant, since 0 is the map default if no entry was found
  306. arr[indexMap[val.name]] = value
  307. }
  308. for _, val := range g.leftIns {
  309. insertValue(val, aConstraint)
  310. }
  311. for _, val := range g.rightIns {
  312. insertValue(val, bConstraint)
  313. }
  314. cConstraint[indexMap[g.value.Out]] = big.NewInt(int64(1))
  315. if g.value.invert {
  316. tmp := aConstraint
  317. aConstraint = cConstraint
  318. cConstraint = tmp
  319. }
  320. r1CS.A = append(r1CS.A, aConstraint)
  321. r1CS.B = append(r1CS.B, bConstraint)
  322. r1CS.C = append(r1CS.C, cConstraint)
  323. } else {
  324. panic("not a m gate")
  325. }
  326. }
  327. return
  328. }
  329. var Utils = prepareUtils()
  330. func fractionToField(in [2]int) *big.Int {
  331. return Utils.FqR.Mul(big.NewInt(int64(in[0])), Utils.FqR.Inverse(big.NewInt(int64(in[1]))))
  332. }
  333. //Calculates the witness (program trace) given some input
  334. //asserts that R1CS has been computed and is stored in the program p memory calling this function
  335. func CalculateWitness(input []*big.Int, r1cs R1CS) (witness []*big.Int) {
  336. witness = r1csqap.ArrayOfBigZeros(len(r1cs.A[0]))
  337. set := make([]bool, len(witness))
  338. witness[0] = big.NewInt(int64(1))
  339. set[0] = true
  340. for i := range input {
  341. witness[i+1] = input[i]
  342. set[i+1] = true
  343. }
  344. zero := big.NewInt(int64(0))
  345. for i := 0; i < len(r1cs.A); i++ {
  346. gatesLeftInputs := r1cs.A[i]
  347. gatesRightInputs := r1cs.B[i]
  348. gatesOutputs := r1cs.C[i]
  349. sumLeft := big.NewInt(int64(0))
  350. sumRight := big.NewInt(int64(0))
  351. sumOut := big.NewInt(int64(0))
  352. index := -1
  353. division := false
  354. for j, val := range gatesLeftInputs {
  355. if val.Cmp(zero) != 0 {
  356. if !set[j] {
  357. index = j
  358. division = true
  359. break
  360. }
  361. sumLeft.Add(sumLeft, new(big.Int).Mul(val, witness[j]))
  362. }
  363. }
  364. for j, val := range gatesRightInputs {
  365. if val.Cmp(zero) != 0 {
  366. sumRight.Add(sumRight, new(big.Int).Mul(val, witness[j]))
  367. }
  368. }
  369. for j, val := range gatesOutputs {
  370. if val.Cmp(zero) != 0 {
  371. if !set[j] {
  372. if index != -1 {
  373. panic("invalid R1CS form")
  374. }
  375. index = j
  376. break
  377. }
  378. sumOut.Add(sumOut, new(big.Int).Mul(val, witness[j]))
  379. }
  380. }
  381. if !division {
  382. set[index] = true
  383. witness[index] = new(big.Int).Mul(sumLeft, sumRight)
  384. } else {
  385. b := sumRight.Int64()
  386. c := sumOut.Int64()
  387. set[index] = true
  388. //TODO replace with proper multiplication of b^-1 within the finite field
  389. witness[index] = big.NewInt(c / b)
  390. //Utils.FqR.Mul(sumOut, Utils.FqR.Inverse(sumRight))
  391. }
  392. }
  393. return
  394. }
  395. var hasher = sha256.New()
  396. func hashTogether(a, b []byte) []byte {
  397. hasher.Reset()
  398. hasher.Write(a)
  399. hasher.Write(b)
  400. return hasher.Sum(nil)
  401. }