diff --git a/README.md b/README.md index e0c6c5e..a8c53f3 100644 --- a/README.md +++ b/README.md @@ -31,9 +31,9 @@ Minimal complete flow implementation: - [x] verify proofs with BN128 pairing Improvements from the minimal implementation: +- [x] allow to call functions in circuits language - [ ] allow `import` in circuits language - [ ] allow `for` in circuits language -- [ ] code to flat code (improve circuit compiler) - [ ] move witness values calculation outside the setup phase - [ ] Groth16 - [ ] multiple optimizations @@ -54,9 +54,13 @@ In this example we will follow the equation example from [Vitalik](https://mediu #### Compile circuit Having a circuit file `test.circuit`: ``` -func test(private s0, public s1): - s2 = s0 * s0 - s3 = s2 * s0 +func exp3(private a): + b = a * a + c = a * b + return c + +func main(private s0, public s1): + s3 = exp3(s0) s4 = s3 + s0 s5 = s4 + 5 equals(s1, s5) @@ -112,9 +116,13 @@ Example: ```go // compile circuit and get the R1CS flatCode := ` -func test(private s0, public s1): - s2 = s0 * s0 - s3 = s2 * s0 +func exp3(private a): + b = a * a + c = a * b + return c + +func main(private s0, public s1): + s3 = exp3(s0) s4 = s3 + s0 s5 = s4 + 5 equals(s1, s5) @@ -147,8 +155,8 @@ a, b, c := circuit.GenerateR1CS() /* now we have the R1CS from the circuit: -a: [[0 0 1 0 0 0 0 0] [0 0 0 1 0 0 0 0] [0 0 1 0 1 0 0 0] [5 0 0 0 0 1 0 0] [0 0 0 0 0 0 1 0] [0 1 0 0 0 0 0 0] [1 0 0 0 0 0 0 0]] -b: [[0 0 1 0 0 0 0 0] [0 0 1 0 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0]] +a: [[0 0 1 0 0 0 0 0] [0 0 1 0 0 0 0 0] [0 0 1 0 1 0 0 0] [5 0 0 0 0 1 0 0] [0 0 0 0 0 0 1 0] [0 1 0 0 0 0 0 0] [1 0 0 0 0 0 0 0]] +b: [[0 0 1 0 0 0 0 0] [0 0 0 1 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0]] c: [[0 0 0 1 0 0 0 0] [0 0 0 0 1 0 0 0] [0 0 0 0 0 1 0 0] [0 0 0 0 0 0 1 0] [0 1 0 0 0 0 0 0] [0 0 0 0 0 0 1 0] [0 0 0 0 0 0 0 1]] */ diff --git a/circuitcompiler/circuit_test.go b/circuitcompiler/circuit_test.go index 485406b..dcf6dab 100644 --- a/circuitcompiler/circuit_test.go +++ b/circuitcompiler/circuit_test.go @@ -11,7 +11,7 @@ import ( func TestCircuitParser(t *testing.T) { // y = x^3 + x + 5 flat := ` - func test(private s0, public s1): + func main(private s0, public s1): s2 = s0 * s0 s3 = s2 * s0 s4 = s3 + s0 @@ -86,3 +86,89 @@ func TestCircuitParser(t *testing.T) { assert.Equal(t, len(circuit.PublicInputs), 1) assert.Equal(t, len(circuit.PrivateInputs), 1) } + +func TestCircuitWithFuncCallsParser(t *testing.T) { + // y = x^3 + x + 5 + code := ` + func exp3(private a): + b = a * a + c = a * b + return c + func sum(private a, private b): + c = a + b + return c + + func main(private s0, public s1): + s3 = exp3(s0) + s4 = sum(s3, s0) + s5 = s4 + 5 + equals(s1, s5) + out = 1 * 1 + ` + parser := NewParser(strings.NewReader(code)) + circuit, err := parser.Parse() + assert.Nil(t, err) + + // flat code to R1CS + a, b, c := circuit.GenerateR1CS() + assert.Equal(t, "s0", circuit.PrivateInputs[0]) + assert.Equal(t, "s1", circuit.PublicInputs[0]) + + assert.Equal(t, []string{"one", "s1", "s0", "b0", "s3", "s4", "s5", "out"}, circuit.Signals) + + // expected result + b0 := big.NewInt(int64(0)) + b1 := big.NewInt(int64(1)) + b5 := big.NewInt(int64(5)) + aExpected := [][]*big.Int{ + []*big.Int{b0, b0, b1, b0, b0, b0, b0, b0}, + []*big.Int{b0, b0, b1, b0, b0, b0, b0, b0}, + []*big.Int{b0, b0, b1, b0, b1, b0, b0, b0}, + []*big.Int{b5, b0, b0, b0, b0, b1, b0, b0}, + []*big.Int{b0, b0, b0, b0, b0, b0, b1, b0}, + []*big.Int{b0, b1, b0, b0, b0, b0, b0, b0}, + []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0}, + } + bExpected := [][]*big.Int{ + []*big.Int{b0, b0, b1, b0, b0, b0, b0, b0}, + []*big.Int{b0, b0, b0, b1, b0, b0, b0, b0}, + []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0}, + []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0}, + []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0}, + []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0}, + []*big.Int{b1, b0, b0, b0, b0, b0, b0, b0}, + } + cExpected := [][]*big.Int{ + []*big.Int{b0, b0, b0, b1, b0, b0, b0, b0}, + []*big.Int{b0, b0, b0, b0, b1, b0, b0, b0}, + []*big.Int{b0, b0, b0, b0, b0, b1, b0, b0}, + []*big.Int{b0, b0, b0, b0, b0, b0, b1, b0}, + []*big.Int{b0, b1, b0, b0, b0, b0, b0, b0}, + []*big.Int{b0, b0, b0, b0, b0, b0, b1, b0}, + []*big.Int{b0, b0, b0, b0, b0, b0, b0, b1}, + } + + assert.Equal(t, aExpected, a) + assert.Equal(t, bExpected, b) + assert.Equal(t, cExpected, c) + + b3 := big.NewInt(int64(3)) + privateInputs := []*big.Int{b3} + b35 := big.NewInt(int64(35)) + publicInputs := []*big.Int{b35} + // Calculate Witness + w, err := circuit.CalculateWitness(privateInputs, publicInputs) + assert.Nil(t, err) + b9 := big.NewInt(int64(9)) + b27 := big.NewInt(int64(27)) + b30 := big.NewInt(int64(30)) + wExpected := []*big.Int{b1, b35, b3, b9, b27, b30, b35, b1} + assert.Equal(t, wExpected, w) + + // circuitJson, _ := json.Marshal(circuit) + // fmt.Println("circuit:", string(circuitJson)) + + assert.Equal(t, circuit.NPublic, 1) + assert.Equal(t, len(circuit.PublicInputs), 1) + assert.Equal(t, len(circuit.PrivateInputs), 1) +} diff --git a/circuitcompiler/parser.go b/circuitcompiler/parser.go index 70f074b..6bbf42b 100644 --- a/circuitcompiler/parser.go +++ b/circuitcompiler/parser.go @@ -6,6 +6,7 @@ import ( "io" "os" "regexp" + "strconv" "strings" ) @@ -68,6 +69,12 @@ func (p *Parser) parseLine() (*Constraint, error) { if err != nil { return c, err } + // get func name + fName := strings.Split(line, "(")[0] + fName = strings.Replace(fName, " ", "", -1) + fName = strings.Replace(fName, " ", "", -1) + c.V1 = fName // so, the name of the func will be in c.V1 + // read string inside ( ) rgx := regexp.MustCompile(`\((.*?)\)`) insideParenthesis := rgx.FindStringSubmatch(line) @@ -105,20 +112,47 @@ func (p *Parser) parseLine() (*Constraint, error) { c.V2 = params[1] return c, nil } - // if c.Literal == "out" { - // // TODO - // return c, nil - // } + if c.Literal == "return" { + _, varToReturn := p.scanIgnoreWhitespace() + c.Out = varToReturn + return c, nil + } _, lit = p.scanIgnoreWhitespace() // skip = c.Literal += lit // v1 _, lit = p.scanIgnoreWhitespace() + + // check if lit is a name of a func that we have declared + if _, ok := circuits[lit]; ok { + // if inside, is calling a declared function + c.Literal = "call" + c.Op = lit // c.Op handles the name of the function called + // put the inputs of the call into the c.PrivateInputs + // format: `funcname(a, b)` + line, err := p.s.r.ReadString(')') + if err != nil { + fmt.Println("ERR", err) + return c, err + } + // read string inside ( ) + rgx := regexp.MustCompile(`\((.*?)\)`) + insideParenthesis := rgx.FindStringSubmatch(line) + varsString := strings.Replace(insideParenthesis[1], " ", "", -1) + params := strings.Split(varsString, ",") + c.PrivateInputs = params + return c, nil + + } + c.V1 = lit c.Literal += lit // operator _, lit = p.scanIgnoreWhitespace() + if lit == "(" { + panic(errors.New("using not declared function")) + } c.Op = lit c.Literal += lit // v2 @@ -150,39 +184,67 @@ func addToArrayIfNotExist(arr []string, elem string) []string { return arr } +func subsIfInMap(original string, m map[string]string) string { + if v, ok := m[original]; ok { + return v + } + return original +} + +var circuits map[string]*Circuit + // Parse parses the lines and returns the compiled Circuit func (p *Parser) Parse() (*Circuit, error) { - circuit := &Circuit{} - circuit.Signals = append(circuit.Signals, "one") + // funcsMap is a map holding the functions names and it's content as Circuit + circuits = make(map[string]*Circuit) + mainExist := false + circuits["main"] = &Circuit{} + callsCount := 0 + + circuits["main"].Signals = append(circuits["main"].Signals, "one") nInputs := 0 + currCircuit := "" for { constraint, err := p.parseLine() if err != nil { break } if constraint.Literal == "func" { + // the name of the func is in constraint.V1 + // check if the name of func is main + if constraint.V1 != "main" { + currCircuit = constraint.V1 + circuits[currCircuit] = &Circuit{} + circuits[currCircuit].Constraints = append(circuits[currCircuit].Constraints, *constraint) + continue + } + currCircuit = "main" + mainExist = true + // l, _ := json.Marshal(constraint) + // fmt.Println(string(l)) + // one constraint for each input for _, in := range constraint.PublicInputs { newConstr := &Constraint{ Op: "in", Out: in, } - circuit.Constraints = append(circuit.Constraints, *newConstr) + circuits[currCircuit].Constraints = append(circuits[currCircuit].Constraints, *newConstr) nInputs++ - circuit.Signals = addToArrayIfNotExist(circuit.Signals, in) - circuit.NPublic++ + circuits[currCircuit].Signals = addToArrayIfNotExist(circuits[currCircuit].Signals, in) + circuits[currCircuit].NPublic++ } for _, in := range constraint.PrivateInputs { newConstr := &Constraint{ Op: "in", Out: in, } - circuit.Constraints = append(circuit.Constraints, *newConstr) + circuits[currCircuit].Constraints = append(circuits[currCircuit].Constraints, *newConstr) nInputs++ - circuit.Signals = addToArrayIfNotExist(circuit.Signals, in) + circuits[currCircuit].Signals = addToArrayIfNotExist(circuits[currCircuit].Signals, in) } - circuit.PublicInputs = constraint.PublicInputs - circuit.PrivateInputs = constraint.PrivateInputs + circuits[currCircuit].PublicInputs = constraint.PublicInputs + circuits[currCircuit].PrivateInputs = constraint.PrivateInputs continue } if constraint.Literal == "equals" { @@ -193,7 +255,7 @@ func (p *Parser) Parse() (*Circuit, error) { Out: constraint.V1, Literal: "equals(" + constraint.V1 + ", " + constraint.V2 + "): " + constraint.V1 + "==" + constraint.V2 + " * 1", } - circuit.Constraints = append(circuit.Constraints, *constr1) + circuits[currCircuit].Constraints = append(circuits[currCircuit].Constraints, *constr1) constr2 := &Constraint{ Op: "*", V1: constraint.V1, @@ -201,42 +263,64 @@ func (p *Parser) Parse() (*Circuit, error) { Out: constraint.V2, Literal: "equals(" + constraint.V1 + ", " + constraint.V2 + "): " + constraint.V2 + "==" + constraint.V1 + " * 1", } - circuit.Constraints = append(circuit.Constraints, *constr2) + circuits[currCircuit].Constraints = append(circuits[currCircuit].Constraints, *constr2) + continue + } + if constraint.Literal == "return" { + currCircuit = "" + continue + } + if constraint.Literal == "call" { + callsCountStr := strconv.Itoa(callsCount) + // for each of the constraints of the called circuit + // add it into the current circuit + signalMap := make(map[string]string) + for i, s := range constraint.PrivateInputs { + // signalMap[s] = circuits[constraint.Op].Constraints[0].PrivateInputs[i] + signalMap[circuits[constraint.Op].Constraints[0].PrivateInputs[i]+callsCountStr] = s + } + // add out to map + signalMap[circuits[constraint.Op].Constraints[len(circuits[constraint.Op].Constraints)-1].Out+callsCountStr] = constraint.Out + + for i := 1; i < len(circuits[constraint.Op].Constraints); i++ { + c := circuits[constraint.Op].Constraints[i] + // add constraint, puting unique names to vars + nc := &Constraint{ + Op: c.Op, + V1: subsIfInMap(c.V1+callsCountStr, signalMap), + V2: subsIfInMap(c.V2+callsCountStr, signalMap), + Out: subsIfInMap(c.Out+callsCountStr, signalMap), + Literal: "", + } + nc.Literal = nc.Out + "=" + nc.V1 + nc.Op + nc.V2 + circuits[currCircuit].Constraints = append(circuits[currCircuit].Constraints, *nc) + } + for _, s := range circuits[constraint.Op].Signals { + circuits[currCircuit].Signals = addToArrayIfNotExist(circuits[currCircuit].Signals, subsIfInMap(s+callsCountStr, signalMap)) + } + callsCount++ continue + } - circuit.Constraints = append(circuit.Constraints, *constraint) + + circuits[currCircuit].Constraints = append(circuits[currCircuit].Constraints, *constraint) isVal, _ := isValue(constraint.V1) if !isVal { - circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.V1) + circuits[currCircuit].Signals = addToArrayIfNotExist(circuits[currCircuit].Signals, constraint.V1) } isVal, _ = isValue(constraint.V2) if !isVal { - circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.V2) + circuits[currCircuit].Signals = addToArrayIfNotExist(circuits[currCircuit].Signals, constraint.V2) } - // if constraint.Out == "out" { - // if Out is "out", put it after first value (one) and before the inputs - // if constraint.Out == circuit.PublicInputs[0] { - // if existInArray(circuit.PublicInputs, constraint.Out) { - // // if Out is a public signal, put it after first value (one) and before the private inputs - // if !existInArray(circuit.Signals, constraint.Out) { - // // if already don't exists in signal array - // signalsCopy := copyArray(circuit.Signals) - // var auxSignals []string - // auxSignals = append(auxSignals, signalsCopy[0]) - // auxSignals = append(auxSignals, constraint.Out) - // auxSignals = append(auxSignals, signalsCopy[1:]...) - // circuit.Signals = auxSignals - // // circuit.PublicInputs = append(circuit.PublicInputs, constraint.Out) - // circuit.NPublic++ - // } - // } else { - circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.Out) - // } + circuits[currCircuit].Signals = addToArrayIfNotExist(circuits[currCircuit].Signals, constraint.Out) + } + circuits["main"].NVars = len(circuits["main"].Signals) + circuits["main"].NSignals = len(circuits["main"].Signals) + if mainExist == false { + return circuits["main"], errors.New("No 'main' func declared") } - circuit.NVars = len(circuit.Signals) - circuit.NSignals = len(circuit.Signals) - return circuit, nil + return circuits["main"], nil } func copyArray(in []string) []string { // tmp var out []string diff --git a/go-snark-cli b/go-snark-cli index cc5e06b..b503f67 100755 Binary files a/go-snark-cli and b/go-snark-cli differ diff --git a/snark_test.go b/snark_test.go index a1c070c..9fdfeae 100644 --- a/snark_test.go +++ b/snark_test.go @@ -18,20 +18,37 @@ func TestZkFromFlatCircuitCode(t *testing.T) { // circuit function // y = x^3 + x + 5 - flatCode := ` - func test(private s0, public s1): - s2 = s0 * s0 - s3 = s2 * s0 - s4 = s3 + s0 - s5 = s4 + 5 - equals(s1, s5) - out = 1 * 1 + code := ` + func exp3(private a): + b = a * a + c = a * b + return c + func sum(private a, private b): + c = a + b + return c + + func main(private s0, public s1): + s3 = exp3(s0) + s4 = sum(s3, s0) + s5 = s4 + 5 + equals(s1, s5) + out = 1 * 1 ` - fmt.Print("\nflat code of the circuit:") - fmt.Println(flatCode) + // the same code without the functions calling, all in one func + // code := ` + // func test(private s0, public s1): + // s2 = s0 * s0 + // s3 = s2 * s0 + // s4 = s3 + s0 + // s5 = s4 + 5 + // equals(s1, s5) + // out = 1 * 1 + // ` + fmt.Print("\ncode of the circuit:") + fmt.Println(code) // parse the code - parser := circuitcompiler.NewParser(strings.NewReader(flatCode)) + parser := circuitcompiler.NewParser(strings.NewReader(code)) circuit, err := parser.Parse() assert.Nil(t, err) // fmt.Println("\ncircuit data:", circuit) @@ -47,8 +64,8 @@ func TestZkFromFlatCircuitCode(t *testing.T) { w, err := circuit.CalculateWitness(privateInputs, publicSignals) assert.Nil(t, err) - // flat code to R1CS - fmt.Println("\ngenerating R1CS from flat code") + // code to R1CS + fmt.Println("\ngenerating R1CS from code") a, b, c := circuit.GenerateR1CS() fmt.Println("\nR1CS:") fmt.Println("a:", a) @@ -132,16 +149,16 @@ func TestZkFromFlatCircuitCode(t *testing.T) { } func TestZkMultiplication(t *testing.T) { - flatCode := ` - func test(private a, private b, public c): + code := ` + func main(private a, private b, public c): d = a * b equals(c, d) out = 1 * 1 ` - fmt.Println("flat code", flatCode) + fmt.Println("code", code) // parse the code - parser := circuitcompiler.NewParser(strings.NewReader(flatCode)) + parser := circuitcompiler.NewParser(strings.NewReader(code)) circuit, err := parser.Parse() assert.Nil(t, err) @@ -155,8 +172,8 @@ func TestZkMultiplication(t *testing.T) { w, err := circuit.CalculateWitness(privateInputs, publicSignals) assert.Nil(t, err) - // flat code to R1CS - fmt.Println("\ngenerating R1CS from flat code") + // code to R1CS + fmt.Println("\ngenerating R1CS from code") a, b, c := circuit.GenerateR1CS() fmt.Println("\nR1CS:") fmt.Println("a:", a) @@ -242,8 +259,8 @@ func TestZkMultiplication(t *testing.T) { func TestMinimalFlow(t *testing.T) { // circuit function // y = x^3 + x + 5 - flatCode := ` - func test(private s0, public s1): + code := ` + func main(private s0, public s1): s2 = s0 * s0 s3 = s2 * s0 s4 = s3 + s0 @@ -251,11 +268,11 @@ func TestMinimalFlow(t *testing.T) { equals(s1, s5) out = 1 * 1 ` - fmt.Print("\nflat code of the circuit:") - fmt.Println(flatCode) + fmt.Print("\ncode of the circuit:") + fmt.Println(code) // parse the code - parser := circuitcompiler.NewParser(strings.NewReader(flatCode)) + parser := circuitcompiler.NewParser(strings.NewReader(code)) circuit, err := parser.Parse() assert.Nil(t, err) @@ -268,8 +285,8 @@ func TestMinimalFlow(t *testing.T) { w, err := circuit.CalculateWitness(privateInputs, publicSignals) assert.Nil(t, err) - // flat code to R1CS - fmt.Println("\ngenerating R1CS from flat code") + // code to R1CS + fmt.Println("\ngenerating R1CS from code") a, b, c := circuit.GenerateR1CS() fmt.Println("\nR1CS:") fmt.Println("a:", a)