diff --git a/README.md b/README.md index d46eeb2..788cdef 100644 --- a/README.md +++ b/README.md @@ -1,173 +1,40 @@ -# go-snark [![Go Report Card](https://goreportcard.com/badge/github.com/arnaucube/go-snark)](https://goreportcard.com/report/github.com/arnaucube/go-snark) - -zkSNARK library implementation in Go - - -- `Succinct Non-Interactive Zero Knowledge for a von Neumann Architecture`, Eli Ben-Sasson, Alessandro Chiesa, Eran Tromer, Madars Virza https://eprint.iacr.org/2013/879.pdf -- `Pinocchio: Nearly practical verifiable computation`, Bryan Parno, Craig Gentry, Jon Howell, Mariana Raykova https://eprint.iacr.org/2013/279.pdf ## Caution, Warning -UNDER CONSTRUCTION! - -Implementation of the zkSNARK [Pinocchio protocol](https://eprint.iacr.org/2013/279.pdf) from scratch in Go to understand the concepts. Do not use in production. +Fork UNDER CONSTRUCTION! Will ask for merge soon -This fork aims to extend its functionalities s.t. one can prove set-membership in zero knowledge. Current implementation status: -- [x] Finite Fields (1, 2, 6, 12) operations -- [x] G1 and G2 curve operations -- [x] BN128 Pairing (to be replaced with less unsecure curve) -- [x] circuit code compiler - - [ ] code to flat code (improve circuit compiler) (in progress) - - [x] flat code compiler -- [x] circuit to R1CS with gate reduction optimisation -- [x] polynomial operations -- [x] R1CS to QAP -- [x] generate trusted setup -- [x] generate proofs -- [x] verify proofs with BN128 pairing +- [x] extended circuit code compiler - [x] move witness calculation outside the setup phase -- [ ] Groth16 (in progress) -- [ ] multiple optimizations - - -## Usage -- [![GoDoc](https://godoc.org/github.com/arnaucube/go-snark?status.svg)](https://godoc.org/github.com/arnaucube/go-snark) zkSnark -- [![GoDoc](https://godoc.org/github.com/arnaucube/go-snark/bn128?status.svg)](https://godoc.org/github.com/arnaucube/go-snark/bn128) bn128 (more details: https://github.com/arnaucube/go-snark/tree/master/bn128) -- [![GoDoc](https://godoc.org/github.com/arnaucube/go-snark/fields?status.svg)](https://godoc.org/github.com/arnaucube/go-snark/fields) Finite Fields operations -- [![GoDoc](https://godoc.org/github.com/arnaucube/go-snark/r1csqap?status.svg)](https://godoc.org/github.com/arnaucube/go-snark/r1csqap) R1CS to QAP (more details: https://github.com/arnaucube/go-snark/tree/master/r1csqap) -- [![GoDoc](https://godoc.org/github.com/arnaucube/go-snark/circuitcompiler?status.svg)](https://godoc.org/github.com/arnaucube/go-snark/circuitcompiler) Circuit Compiler - -### CLI usage - -#### Compile circuit -Having a circuit file `test.circuit`: -``` -func test(private s0, public s1): - s2 = s0 * s0 - s3 = s2 * s0 - s4 = s3 + s0 - s5 = s4 + 5 - equals(s1, s5) - out = 1 * 1 -``` -And a private inputs file `privateInputs.json` -``` -[ - 3 -] -``` -And a public inputs file `publicInputs.json` -``` -[ - 35 -] -``` - -In the command line, execute: -``` -> ./go-snark-cli compile test.circuit -``` - -This will output the `compiledcircuit.json` file. - -#### Trusted Setup -Having the `compiledcircuit.json`, now we can generate the `TrustedSetup`: -``` -> ./go-snark-cli trustedsetup -``` -This will create the file `trustedsetup.json` with the TrustedSetup data, and also a `toxic.json` file, with the parameters to delete from the `Trusted Setup`. - - -#### Generate Proofs -Assumming that we have the `compiledcircuit.json`, `trustedsetup.json`, `privateInputs.json` and the `publicInputs.json` we can now generate the `Proofs` with the following command: -``` -> ./go-snark-cli genproofs -``` - -This will store the file `proofs.json`, that contains all the SNARK proofs. - -#### Verify Proofs -Having the `proofs.json`, `compiledcircuit.json`, `trustedsetup.json` `publicInputs.json` files, we can now verify the `Pairings` of the proofs, in order to verify the proofs. -``` -> ./go-snark-cli verify -``` -This will return a `true` if the proofs are verified, or a `false` if the proofs are not verified. - +- [x] fixed hard bugs ### Library usage Warning: not finished. -Example: +Working example of gate-reduction and code parsing: ```go -// compile circuit and get the R1CS -flatCode := ` -func test(private s0, public s1): - s2 = s0 * s0 - s3 = s2 * s0 - s4 = s3 + s0 - s5 = s4 + 5 - equals(s1, s5) - out = 1 * 1 -` - -// parse the code -parser := circuitcompiler.NewParser(strings.NewReader(flatCode)) -circuit, err := parser.Parse() -assert.Nil(t, err) -fmt.Println(circuit) - - -b3 := big.NewInt(int64(3)) -privateInputs := []*big.Int{b3} -b35 := big.NewInt(int64(35)) -publicSignals := []*big.Int{b35} - -// witness -w, err := circuit.CalculateWitness(privateInputs, publicSignals) -assert.Nil(t, err) -fmt.Println("witness", w) - -// now we have the witness: -// w = [1 35 3 9 27 30 35 1] - -// flat code to R1CS -fmt.Println("generating R1CS from flat code") -a, b, c := circuit.GenerateR1CS() - -/* -now we have the R1CS from the circuit: -a: [[0 0 1 0 0 0 0 0] [0 0 0 1 0 0 0 0] [0 0 1 0 1 0 0 0] [5 0 0 0 0 1 0 0] [0 0 0 0 0 0 1 0] [0 1 0 0 0 0 0 0] [1 0 0 0 0 0 0 0]] -b: [[0 0 1 0 0 0 0 0] [0 0 1 0 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0]] -c: [[0 0 0 1 0 0 0 0] [0 0 0 0 1 0 0 0] [0 0 0 0 0 1 0 0] [0 0 0 0 0 0 1 0] [0 1 0 0 0 0 0 0] [0 0 0 0 0 0 1 0] [0 0 0 0 0 0 0 1]] -*/ - - -alphas, betas, gammas, _ := snark.Utils.PF.R1CSToQAP(a, b, c) - - -ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas) - -// calculate trusted setup -setup, err := GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas) - -hx := Utils.PF.DivisorPolynomial(px, setup.Pk.Z) - -proof, err := GenerateProofs(*circuit, setup, w, px) - -b35Verif := big.NewInt(int64(35)) -publicSignalsVerif := []*big.Int{b35Verif} -assert.True(t, VerifyProof(*circuit, setup, proof, publicSignalsVerif, true)) -``` - - - -## Test -``` -go test ./... -v +def do(x): + e = x * 5 + b = e * 6 + c = b * 7 + f = c * 1 + d = c * f + out = d * mul(d,e) + + def doSomethingElse(x ,k): + z = k * x + out = do(x) + mul(x,z) + + def main(x,z): + out = do(z) + doSomethingElse(x,x) + + def mul(a,b): + out = a * b +``` +R1CS Output: +```go +[[0 0 210 0 0 0 0 0 0 0 0 0 0 0] [0 0 210 0 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 1 0 0 0 0 0 0 0 0 0] [0 0 0 1 0 0 0 0 0 0 0 0 0 0] [0 210 0 0 0 0 0 0 0 0 0 0 0 0] [0 210 0 0 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 1 0 0 0 0 0] [0 0 0 0 0 0 0 1 0 0 0 0 0 0] [0 1 0 0 0 0 0 0 0 0 0 0 0 0] [0 1 0 0 0 0 0 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0 0 0 0 0 0 0]] +[[0 0 210 0 0 0 0 0 0 0 0 0 0 0] [0 0 210 0 0 0 0 0 0 0 0 0 0 0] [0 0 5 0 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 1 0 0 0 0 0 0 0 0] [0 210 0 0 0 0 0 0 0 0 0 0 0 0] [0 210 0 0 0 0 0 0 0 0 0 0 0 0] [0 5 0 0 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 1 0 0 0 0] [0 1 0 0 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 0 0 1 0 0] [0 0 0 0 0 0 1 0 0 0 1 0 1 0]] +[[0 0 0 1 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 1 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 1 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 1 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 1 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 1 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 1 0 0 0 0] [0 0 0 0 0 0 0 0 0 0 1 0 0 0] [0 0 0 0 0 0 0 0 0 0 0 1 0 0] [0 0 0 0 0 0 0 0 0 0 0 0 1 0] [0 0 0 0 0 0 0 0 0 0 0 0 0 1]] ``` - ---- - - -Thanks to [@jbaylina](https://github.com/jbaylina), [@bellesmarta](https://github.com/bellesmarta), [@adriamb](https://github.com/adriamb) for their explanations that helped to understand this a little bit. Also thanks to [@vbuterin](https://github.com/vbuterin) for all the published articles explaining the zkSNARKs. +Note that we only need 11 multiplication Gates instead of 16 diff --git a/circuitcompiler/Programm.go b/circuitcompiler/Programm.go index ccd6e13..eb6ff8c 100644 --- a/circuitcompiler/Programm.go +++ b/circuitcompiler/Programm.go @@ -1,10 +1,12 @@ package circuitcompiler import ( + "crypto/sha256" "fmt" "github.com/mottla/go-snark/bn128" "github.com/mottla/go-snark/fields" "github.com/mottla/go-snark/r1csqap" + "hash" "math/big" "sync" ) @@ -19,8 +21,9 @@ type Program struct { functions map[string]*Circuit globalInputs []string arithmeticEnvironment utils //find a better name - - R1CS struct { + sha256Hasher hash.Hash + computedInContext map[string]map[string]string + R1CS struct { A [][]*big.Int B [][]*big.Int C [][]*big.Int @@ -45,10 +48,12 @@ func (p *Program) BuildConstraintTrees() { p.getMainCircuit().gateMap[mainRoot.value.Out] = mainRoot } + //for _, in := range p.getMainCircuit().Inputs { + // p.globalInputs = append(p.globalInputs, composeNewFunction(in, p.getMainCircuit().Inputs)) + //} for _, in := range p.getMainCircuit().Inputs { - p.globalInputs = append(p.globalInputs, composeNewFunction(in, p.getMainCircuit().Inputs)) + p.globalInputs = append(p.globalInputs, in) } - var wg = sync.WaitGroup{} for _, circuit := range p.functions { @@ -101,67 +106,86 @@ func (c *Circuit) buildTree(g *gate) { } func (p *Program) ReduceCombinedTree() (orderedmGates []gate) { - mGatesUsed := make(map[string]bool) + //mGatesUsed := make(map[string]bool) orderedmGates = []gate{} - p.r1CSRecursiveBuild(p.getMainCircuit(), p.getMainCircuit().root, mGatesUsed, &orderedmGates, false, false) + p.computedInContext = make(map[string]map[string]string) + rootHash := []byte{} + p.computedInContext[string(rootHash)] = make(map[string]string) + p.r1CSRecursiveBuild(p.getMainCircuit(), p.getMainCircuit().root, rootHash, &orderedmGates, false, false) return orderedmGates } -func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, root *gate, mGatesUsed map[string]bool, orderedmGates *[]gate, negate bool, inverse bool) (variableEnd bool) { +func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTraceBuildup []byte, orderedmGates *[]gate, negate bool, invert bool) (facs []factor, hashTraceResult []byte, variableEnd bool) { - if root.OperationType() == IN { - return true - } + if node.OperationType() == CONST { + b1, v1 := isValue(node.value.Out) + if !b1 { + panic("not a constant") + } + mul := [2]int{v1, 1} + if invert { + mul = [2]int{1, v1} - if root.OperationType() == CONST { - return false + } + return []factor{{typ: CONST, negate: negate, multiplicative: mul}}, make([]byte, 10), false } - if root.OperationType() == FUNC { - nextContext := p.extendedFunctionRenamer(currentCircuit, root.value) + if node.OperationType() == FUNC { + nextContext := p.extendedFunctionRenamer(currentCircuit, node.value) currentCircuit = nextContext - root = nextContext.root + node = nextContext.root + hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(currentCircuit.currentOutputName())) + if _, ex := p.computedInContext[string(hashTraceBuildup)]; !ex { + p.computedInContext[string(hashTraceBuildup)] = make(map[string]string) + } + } - originOfVariable := p.functions[getContextFromVariable(root.value.Out)] - if _, alreadyComputed := mGatesUsed[composeNewFunction(root.value.Out, originOfVariable.currentOutputs())]; alreadyComputed { - return true + if node.OperationType() == IN { + fac := factor{typ: IN, name: node.value.Out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}} + hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(node.value.Out)) + return []factor{fac}, hashTraceBuildup, true } - variableEnd = p.r1CSRecursiveBuild(currentCircuit, root.left, mGatesUsed, orderedmGates, negate, inverse) + if out, ex := p.computedInContext[string(hashTraceBuildup)][node.value.Out]; ex { + fac := factor{typ: IN, name: out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}} + hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(node.value.Out)) + return []factor{fac}, hashTraceBuildup, true + } - cons := p.r1CSRecursiveBuild(currentCircuit, root.right, mGatesUsed, orderedmGates, Xor(negate, root.value.negate), Xor(inverse, root.value.invert)) + leftFactors, leftHash, variableEnd := p.r1CSRecursiveBuild(currentCircuit, node.left, hashTraceBuildup, orderedmGates, negate, invert) - if root.OperationType() == MULTIPLY { + rightFactors, rightHash, cons := p.r1CSRecursiveBuild(currentCircuit, node.right, hashTraceBuildup, orderedmGates, Xor(negate, node.value.negate), Xor(invert, node.value.invert)) - if !(variableEnd && cons) && !root.value.invert && root != p.getMainCircuit().root { - return variableEnd || cons - } - root.leftIns = p.collectFactors(currentCircuit, root.left, mGatesUsed, false, false) - //if root.left.value.Out== root.right.value.Out{ - // //note this is not a full copy, but shouldnt be a problem - // root.rightIns= root.leftIns - //}else{ - // collectAtomsInSubtree(root.right, mGatesUsed, 1, root.rightIns, functionRootMap, Xor(negate, root.value.negate), Xor(inverse, root.value.invert)) - //} - //root.rightIns = collectAtomsInSubtree3(root.right, mGatesUsed, Xor(negate, root.value.negate), Xor(inverse, root.value.invert)) - root.rightIns = p.collectFactors(currentCircuit, root.right, mGatesUsed, false, false) - root.index = len(mGatesUsed) - var nn = composeNewFunction(root.value.Out, originOfVariable.currentOutputs()) - //var nn = root.value.Out - //if _, ex := p.functions[root.value.Out]; ex { - // nn = currentCircuit.currentOutputName() - //} - if _, ex := mGatesUsed[nn]; ex { - panic(fmt.Sprintf("told ya so %v", nn)) + if node.OperationType() == MULTIPLY { + + if !(variableEnd && cons) && !node.value.invert && node != p.getMainCircuit().root { + //if !(variableEnd && cons) && !node.value.invert && node != p.getMainCircuit().root { + return mulFactors(leftFactors, rightFactors), append(leftHash, rightHash...), variableEnd || cons } - mGatesUsed[nn] = true - rootGate := cloneGate(root) - rootGate.value.Out = nn + rootGate := cloneGate(node) + rootGate.index = len(*orderedmGates) + rootGate.leftIns = leftFactors + rootGate.rightIns = rightFactors + out := hashTogether(leftHash, rightHash) + rootGate.value.V1 = rootGate.value.V1 + string(leftHash[:10]) + rootGate.value.V2 = rootGate.value.V2 + string(rightHash[:10]) + rootGate.value.Out = rootGate.value.Out + string(out[:10]) + p.computedInContext[string(hashTraceBuildup)][node.value.Out] = rootGate.value.Out *orderedmGates = append(*orderedmGates, *rootGate) + + hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(rootGate.value.Out)) + + return []factor{{typ: IN, name: rootGate.value.Out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}}, hashTraceBuildup, true + } + + switch node.OperationType() { + case PLUS: + return addFactors(leftFactors, rightFactors), hashTogether(leftHash, rightHash), variableEnd || cons + default: + panic("unexpected gate") } - return variableEnd || cons //TODO optimize if output is not a multipication gate } @@ -298,61 +322,6 @@ func addFactors(leftFactors, rightFactors []factor) []factor { return res } -func (p *Program) collectFactors(contextCircut *Circuit, node *gate, mGatesUsed map[string]bool, negate bool, invert bool) []factor { - - if node.OperationType() == CONST { - b1, v1 := isValue(node.value.Out) - if !b1 { - panic("not a constant") - } - if invert { - return []factor{{typ: CONST, negate: negate, multiplicative: [2]int{1, v1}}} - } - return []factor{{typ: CONST, negate: negate, multiplicative: [2]int{v1, 1}}} - } - - if node.OperationType() == FUNC { - nextContext := p.extendedFunctionRenamer(contextCircut, node.value) - - //if _, ex := mGatesUsed[nextContext.currentOutputName()]; ex { - // return []factor{{typ: IN, name: nextContext.currentOutputName(), invert: invert, negate: negate, multiplicative: [2]int{1, 1}}} - //} - contextCircut = nextContext - node = nextContext.root - } - - originOfVariable := p.functions[getContextFromVariable(node.value.Out)] - if originOfVariable == nil { - fmt.Println("asdf") - } - lookingFOr := composeNewFunction(node.value.Out, originOfVariable.currentOutputs()) - - //if _, ex := mGatesUsed[node.value.Out]; ex { - // return []factor{{typ: IN, name: node.value.Out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}} - //} - - if node.OperationType() == IN { - return []factor{{typ: IN, name: lookingFOr, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}} - } - - if _, alreadyComputed := mGatesUsed[lookingFOr]; alreadyComputed { - return []factor{{typ: IN, name: lookingFOr, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}} - } - - leftFactors := p.collectFactors(contextCircut, node.left, mGatesUsed, negate, invert) - rightFactors := p.collectFactors(contextCircut, node.right, mGatesUsed, Xor(negate, node.value.negate), Xor(invert, node.value.invert)) - - switch node.OperationType() { - case MULTIPLY: - return mulFactors(leftFactors, rightFactors) - case PLUS: - return addFactors(leftFactors, rightFactors) - default: - panic("unexpected gate") - } - -} - //copies a gate neglecting its references to other gates func cloneGate(in *gate) (out *gate) { constr := &Constraint{Inputs: in.value.Inputs, Out: in.value.Out, Op: in.value.Op, invert: in.value.invert, negate: in.value.negate, V2: in.value.V2, V1: in.value.V1} @@ -445,7 +414,12 @@ func (p *Program) extendedFunctionRenamer(contextCircuit *Circuit, constraint *C } func NewProgram() (p *Program) { - p = &Program{functions: map[string]*Circuit{}, globalInputs: []string{"one"}, arithmeticEnvironment: prepareUtils()} + p = &Program{ + functions: map[string]*Circuit{}, + globalInputs: []string{"one"}, + arithmeticEnvironment: prepareUtils(), + sha256Hasher: sha256.New(), + } return } @@ -606,3 +580,18 @@ func (p *Program) CalculateWitness(input []*big.Int) (witness []*big.Int) { return } + +var hasher = sha256.New() + +func hashFactorWithContext(f factor, currentCircuit *Circuit) []byte { + hasher.Reset() + hasher.Write([]byte(f.name)) + hasher.Write([]byte(currentCircuit.currentOutputName())) + return hasher.Sum(nil) +} +func hashTogether(a, b []byte) []byte { + hasher.Reset() + hasher.Write(a) + hasher.Write(b) + return hasher.Sum(nil) +} diff --git a/circuitcompiler/Programm_test.go b/circuitcompiler/Programm_test.go index 4e8c43c..8604846 100644 --- a/circuitcompiler/Programm_test.go +++ b/circuitcompiler/Programm_test.go @@ -2,31 +2,36 @@ package circuitcompiler import ( "fmt" + "github.com/stretchr/testify/assert" "math/big" "strings" "testing" ) -func TestProgramm_BuildConstraintTree(t *testing.T) { - line := "asdf asfd" - line = strings.TrimFunc(line, func(i rune) bool { return isWhitespace(i) }) - fmt.Println(line) +type InOut struct { + inputs []*big.Int + result *big.Int } -func TestNewProgramm(t *testing.T) { - //flat := ` - // - //def add(x ,k): - // z = k * x - // out = 6 + mul(x,z) - // - //def main(x,z): - // out = mul(x,z) - add(x,x) - // - //def mul(a,b): - // out = a * b - //` - flat := ` +type TraceCorrectnessTest struct { + code string + io []InOut +} + +var bigNumberResult1, _ = new(big.Int).SetString("2297704271284150716235246193843898764109352875", 10) +var bigNumberResult2, _ = new(big.Int).SetString("75263346540254220740876250", 10) + +var correctnesTest = []TraceCorrectnessTest{ + { + io: []InOut{{ + inputs: []*big.Int{big.NewInt(int64(7)), big.NewInt(int64(11))}, + result: big.NewInt(int64(1729500084900343)), + }, { + inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))}, + + result: bigNumberResult1, + }}, + code: ` def do(x): e = x * 5 b = e * 6 @@ -42,60 +47,112 @@ func TestNewProgramm(t *testing.T) { def main(x,z): out = do(z) + add(x,x) + def mul(a,b): + out = a * b + `, + }, + {io: []InOut{{ + inputs: []*big.Int{big.NewInt(int64(7))}, + result: big.NewInt(int64(4)), + }}, + code: ` + def mul(a,b): + out = a * b + + def main(a): + b = a * a + c = 4 - b + d = 5 * c + out = mul(d,c) / mul(b,b) + `, + }, + {io: []InOut{{ + inputs: []*big.Int{big.NewInt(int64(7)), big.NewInt(int64(11))}, + result: big.NewInt(int64(22638)), + }, { + inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))}, + result: bigNumberResult2, + }}, + code: ` + def main(a,b): + d = b + b + c = a * d + e = c - a + out = e * c + `, + }, + { + io: []InOut{{ + inputs: []*big.Int{big.NewInt(int64(643)), big.NewInt(int64(76548465))}, + result: big.NewInt(int64(98441327276)), + }, { + inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))}, + result: big.NewInt(int64(8675445947220)), + }}, + code: ` + def main(a,b): + c = a + b + e = c - a + f = e + b + g = f + 2 + out = g * a + `, + }, +} + +func TestNewProgramm(t *testing.T) { + flat := ` + + def doSomething(x ,k): + z = k * x + out = 6 + mul(x,z) + + def main(x,z): + out = mul(x,z) - doSomething(x,x) + def mul(a,b): out = a * b ` - //flat := ` - //func mul(a,b): - // out = a * b - // - //func main(a,aa): - // b = a * a - // c = 4 - b - // d = 5 * c - // out = mul(d,c) / mul(b,b) - //` - //flat := ` - //func main(a,b): - // c = a + b - // e = c - a - // f = e + b - // g = f + 2 - // out = g * a - //` - parser := NewParser(strings.NewReader(flat)) - program, err := parser.Parse() + for _, test := range correctnesTest { + parser := NewParser(strings.NewReader(test.code)) + program, err := parser.Parse() - if err != nil { - panic(err) - } - fmt.Println("\n unreduced") - fmt.Println(flat) + if err != nil { + panic(err) + } + fmt.Println("\n unreduced") + fmt.Println(test.code) - program.BuildConstraintTrees() - for k, v := range program.functions { - fmt.Println(k) - PrintTree(v.root) - } + program.BuildConstraintTrees() + for k, v := range program.functions { + fmt.Println(k) + PrintTree(v.root) + } - fmt.Println("\nReduced gates") - //PrintTree(froots["mul"]) - gates := program.ReduceCombinedTree() - for _, g := range gates { - fmt.Printf("\n %v", g) - } + fmt.Println("\nReduced gates") + //PrintTree(froots["mul"]) + gates := program.ReduceCombinedTree() + for _, g := range gates { + fmt.Printf("\n %v", g) + } - fmt.Println("generating R1CS") - a, b, c := program.GenerateReducedR1CS(gates) - fmt.Println(a) - fmt.Println(b) - fmt.Println(c) - a1 := big.NewInt(int64(7)) - a2 := big.NewInt(int64(11)) - inputs := []*big.Int{a1, a2} - w := program.CalculateWitness(inputs) - fmt.Println("witness") - fmt.Println(w) + fmt.Println("\n generating R1CS") + a, b, c := program.GenerateReducedR1CS(gates) + fmt.Println(a) + fmt.Println(b) + fmt.Println(c) + + for _, io := range test.io { + inputs := io.inputs + fmt.Println("input") + fmt.Println(inputs) + w := program.CalculateWitness(inputs) + fmt.Println("witness") + fmt.Println(w) + assert.Equal(t, io.result, w[len(w)-1]) + } + + } } diff --git a/circuitcompiler/circuit.go b/circuitcompiler/circuit.go index bb0d267..d6e77ae 100644 --- a/circuitcompiler/circuit.go +++ b/circuitcompiler/circuit.go @@ -70,7 +70,6 @@ func newCircuit(name string) *Circuit { func (p *Program) addFunction(constraint *Constraint) (c *Circuit) { name := constraint.Out - fmt.Println("try to add function ", name) b, name2, _ := isFunction(name) if !b {