Browse Source

circuitcompiler allow to call declared functions in circuits language

pull/10/head
arnaucube 5 years ago
parent
commit
165699b58f
5 changed files with 271 additions and 76 deletions
  1. +17
    -9
      README.md
  2. +87
    -1
      circuitcompiler/circuit_test.go
  3. +124
    -40
      circuitcompiler/parser.go
  4. BIN
      go-snark-cli
  5. +43
    -26
      snark_test.go

+ 17
- 9
README.md

@ -31,9 +31,9 @@ Minimal complete flow implementation:
- [x] verify proofs with BN128 pairing
Improvements from the minimal implementation:
- [x] allow to call functions in circuits language
- [ ] allow `import` in circuits language
- [ ] allow `for` in circuits language
- [ ] code to flat code (improve circuit compiler)
- [ ] move witness values calculation outside the setup phase
- [ ] Groth16
- [ ] multiple optimizations
@ -54,9 +54,13 @@ In this example we will follow the equation example from [Vitalik](https://mediu
#### Compile circuit
Having a circuit file `test.circuit`:
```
func test(private s0, public s1):
s2 = s0 * s0
s3 = s2 * s0
func exp3(private a):
b = a * a
c = a * b
return c
func main(private s0, public s1):
s3 = exp3(s0)
s4 = s3 + s0
s5 = s4 + 5
equals(s1, s5)
@ -112,9 +116,13 @@ Example:
```go
// compile circuit and get the R1CS
flatCode := `
func test(private s0, public s1):
s2 = s0 * s0
s3 = s2 * s0
func exp3(private a):
b = a * a
c = a * b
return c
func main(private s0, public s1):
s3 = exp3(s0)
s4 = s3 + s0
s5 = s4 + 5
equals(s1, s5)
@ -147,8 +155,8 @@ a, b, c := circuit.GenerateR1CS()
/*
now we have the R1CS from the circuit:
a: [[0 0 1 0 0 0 0 0] [0 0 0 1 0 0 0 0] [0 0 1 0 1 0 0 0] [5 0 0 0 0 1 0 0] [0 0 0 0 0 0 1 0] [0 1 0 0 0 0 0 0] [1 0 0 0 0 0 0 0]]
b: [[0 0 1 0 0 0 0 0] [0 0 1 0 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0]]
a: [[0 0 1 0 0 0 0 0] [0 0 1 0 0 0 0 0] [0 0 1 0 1 0 0 0] [5 0 0 0 0 1 0 0] [0 0 0 0 0 0 1 0] [0 1 0 0 0 0 0 0] [1 0 0 0 0 0 0 0]]
b: [[0 0 1 0 0 0 0 0] [0 0 0 1 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0]]
c: [[0 0 0 1 0 0 0 0] [0 0 0 0 1 0 0 0] [0 0 0 0 0 1 0 0] [0 0 0 0 0 0 1 0] [0 1 0 0 0 0 0 0] [0 0 0 0 0 0 1 0] [0 0 0 0 0 0 0 1]]
*/

+ 87
- 1
circuitcompiler/circuit_test.go

@ -11,7 +11,7 @@ import (
func TestCircuitParser(t *testing.T) {
// y = x^3 + x + 5
flat := `
func test(private s0, public s1):
func main(private s0, public s1):
s2 = s0 * s0
s3 = s2 * s0
s4 = s3 + s0
@ -86,3 +86,89 @@ func TestCircuitParser(t *testing.T) {
assert.Equal(t, len(circuit.PublicInputs), 1)
assert.Equal(t, len(circuit.PrivateInputs), 1)
}
func TestCircuitWithFuncCallsParser(t *testing.T) {
// y = x^3 + x + 5
code := `
func exp3(private a):
b = a * a
c = a * b
return c
func sum(private a, private b):
c = a + b
return c
func main(private s0, public s1):
s3 = exp3(s0)
s4 = sum(s3, s0)
s5 = s4 + 5
equals(s1, s5)
out = 1 * 1
`
parser := NewParser(strings.NewReader(code))
circuit, err := parser.Parse()
assert.Nil(t, err)
// flat code to R1CS
a, b, c := circuit.GenerateR1CS()
assert.Equal(t, "s0", circuit.PrivateInputs[0])
assert.Equal(t, "s1", circuit.PublicInputs[0])
assert.Equal(t, []string{"one", "s1", "s0", "b0", "s3", "s4", "s5", "out"}, circuit.Signals)
// expected result
b0 := big.NewInt(int64(0))
b1 := big.NewInt(int64(1))
b5 := big.NewInt(int64(5))
aExpected := [][]*big.Int{
[]*big.Int{b0, b0, b1, b0, b0, b0, b0, b0},
[]*big.Int{b0, b0, b1, b0, b0, b0, b0, b0},
[]*big.Int{b0, b0, b1, b0, b1, b0, b0, b0},
[]*big.Int{b5, b0, b0, b0, b0, b1, b0, b0},
[]*big.Int{b0, b0, b0, b0, b0, b0, b1, b0},
[]*big.Int{b0, b1, b0, b0, b0, b0, b0, b0},
[]*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
}
bExpected := [][]*big.Int{
[]*big.Int{b0, b0, b1, b0, b0, b0, b0, b0},
[]*big.Int{b0, b0, b0, b1, b0, b0, b0, b0},
[]*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
[]*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
[]*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
[]*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
[]*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
}
cExpected := [][]*big.Int{
[]*big.Int{b0, b0, b0, b1, b0, b0, b0, b0},
[]*big.Int{b0, b0, b0, b0, b1, b0, b0, b0},
[]*big.Int{b0, b0, b0, b0, b0, b1, b0, b0},
[]*big.Int{b0, b0, b0, b0, b0, b0, b1, b0},
[]*big.Int{b0, b1, b0, b0, b0, b0, b0, b0},
[]*big.Int{b0, b0, b0, b0, b0, b0, b1, b0},
[]*big.Int{b0, b0, b0, b0, b0, b0, b0, b1},
}
assert.Equal(t, aExpected, a)
assert.Equal(t, bExpected, b)
assert.Equal(t, cExpected, c)
b3 := big.NewInt(int64(3))
privateInputs := []*big.Int{b3}
b35 := big.NewInt(int64(35))
publicInputs := []*big.Int{b35}
// Calculate Witness
w, err := circuit.CalculateWitness(privateInputs, publicInputs)
assert.Nil(t, err)
b9 := big.NewInt(int64(9))
b27 := big.NewInt(int64(27))
b30 := big.NewInt(int64(30))
wExpected := []*big.Int{b1, b35, b3, b9, b27, b30, b35, b1}
assert.Equal(t, wExpected, w)
// circuitJson, _ := json.Marshal(circuit)
// fmt.Println("circuit:", string(circuitJson))
assert.Equal(t, circuit.NPublic, 1)
assert.Equal(t, len(circuit.PublicInputs), 1)
assert.Equal(t, len(circuit.PrivateInputs), 1)
}

+ 124
- 40
circuitcompiler/parser.go

@ -6,6 +6,7 @@ import (
"io"
"os"
"regexp"
"strconv"
"strings"
)
@ -68,6 +69,12 @@ func (p *Parser) parseLine() (*Constraint, error) {
if err != nil {
return c, err
}
// get func name
fName := strings.Split(line, "(")[0]
fName = strings.Replace(fName, " ", "", -1)
fName = strings.Replace(fName, " ", "", -1)
c.V1 = fName // so, the name of the func will be in c.V1
// read string inside ( )
rgx := regexp.MustCompile(`\((.*?)\)`)
insideParenthesis := rgx.FindStringSubmatch(line)
@ -105,20 +112,47 @@ func (p *Parser) parseLine() (*Constraint, error) {
c.V2 = params[1]
return c, nil
}
// if c.Literal == "out" {
// // TODO
// return c, nil
// }
if c.Literal == "return" {
_, varToReturn := p.scanIgnoreWhitespace()
c.Out = varToReturn
return c, nil
}
_, lit = p.scanIgnoreWhitespace() // skip =
c.Literal += lit
// v1
_, lit = p.scanIgnoreWhitespace()
// check if lit is a name of a func that we have declared
if _, ok := circuits[lit]; ok {
// if inside, is calling a declared function
c.Literal = "call"
c.Op = lit // c.Op handles the name of the function called
// put the inputs of the call into the c.PrivateInputs
// format: `funcname(a, b)`
line, err := p.s.r.ReadString(')')
if err != nil {
fmt.Println("ERR", err)
return c, err
}
// read string inside ( )
rgx := regexp.MustCompile(`\((.*?)\)`)
insideParenthesis := rgx.FindStringSubmatch(line)
varsString := strings.Replace(insideParenthesis[1], " ", "", -1)
params := strings.Split(varsString, ",")
c.PrivateInputs = params
return c, nil
}
c.V1 = lit
c.Literal += lit
// operator
_, lit = p.scanIgnoreWhitespace()
if lit == "(" {
panic(errors.New("using not declared function"))
}
c.Op = lit
c.Literal += lit
// v2
@ -150,39 +184,67 @@ func addToArrayIfNotExist(arr []string, elem string) []string {
return arr
}
func subsIfInMap(original string, m map[string]string) string {
if v, ok := m[original]; ok {
return v
}
return original
}
var circuits map[string]*Circuit
// Parse parses the lines and returns the compiled Circuit
func (p *Parser) Parse() (*Circuit, error) {
circuit := &Circuit{}
circuit.Signals = append(circuit.Signals, "one")
// funcsMap is a map holding the functions names and it's content as Circuit
circuits = make(map[string]*Circuit)
mainExist := false
circuits["main"] = &Circuit{}
callsCount := 0
circuits["main"].Signals = append(circuits["main"].Signals, "one")
nInputs := 0
currCircuit := ""
for {
constraint, err := p.parseLine()
if err != nil {
break
}
if constraint.Literal == "func" {
// the name of the func is in constraint.V1
// check if the name of func is main
if constraint.V1 != "main" {
currCircuit = constraint.V1
circuits[currCircuit] = &Circuit{}
circuits[currCircuit].Constraints = append(circuits[currCircuit].Constraints, *constraint)
continue
}
currCircuit = "main"
mainExist = true
// l, _ := json.Marshal(constraint)
// fmt.Println(string(l))
// one constraint for each input
for _, in := range constraint.PublicInputs {
newConstr := &Constraint{
Op: "in",
Out: in,
}
circuit.Constraints = append(circuit.Constraints, *newConstr)
circuits[currCircuit].Constraints = append(circuits[currCircuit].Constraints, *newConstr)
nInputs++
circuit.Signals = addToArrayIfNotExist(circuit.Signals, in)
circuit.NPublic++
circuits[currCircuit].Signals = addToArrayIfNotExist(circuits[currCircuit].Signals, in)
circuits[currCircuit].NPublic++
}
for _, in := range constraint.PrivateInputs {
newConstr := &Constraint{
Op: "in",
Out: in,
}
circuit.Constraints = append(circuit.Constraints, *newConstr)
circuits[currCircuit].Constraints = append(circuits[currCircuit].Constraints, *newConstr)
nInputs++
circuit.Signals = addToArrayIfNotExist(circuit.Signals, in)
circuits[currCircuit].Signals = addToArrayIfNotExist(circuits[currCircuit].Signals, in)
}
circuit.PublicInputs = constraint.PublicInputs
circuit.PrivateInputs = constraint.PrivateInputs
circuits[currCircuit].PublicInputs = constraint.PublicInputs
circuits[currCircuit].PrivateInputs = constraint.PrivateInputs
continue
}
if constraint.Literal == "equals" {
@ -193,7 +255,7 @@ func (p *Parser) Parse() (*Circuit, error) {
Out: constraint.V1,
Literal: "equals(" + constraint.V1 + ", " + constraint.V2 + "): " + constraint.V1 + "==" + constraint.V2 + " * 1",
}
circuit.Constraints = append(circuit.Constraints, *constr1)
circuits[currCircuit].Constraints = append(circuits[currCircuit].Constraints, *constr1)
constr2 := &Constraint{
Op: "*",
V1: constraint.V1,
@ -201,42 +263,64 @@ func (p *Parser) Parse() (*Circuit, error) {
Out: constraint.V2,
Literal: "equals(" + constraint.V1 + ", " + constraint.V2 + "): " + constraint.V2 + "==" + constraint.V1 + " * 1",
}
circuit.Constraints = append(circuit.Constraints, *constr2)
circuits[currCircuit].Constraints = append(circuits[currCircuit].Constraints, *constr2)
continue
}
if constraint.Literal == "return" {
currCircuit = ""
continue
}
if constraint.Literal == "call" {
callsCountStr := strconv.Itoa(callsCount)
// for each of the constraints of the called circuit
// add it into the current circuit
signalMap := make(map[string]string)
for i, s := range constraint.PrivateInputs {
// signalMap[s] = circuits[constraint.Op].Constraints[0].PrivateInputs[i]
signalMap[circuits[constraint.Op].Constraints[0].PrivateInputs[i]+callsCountStr] = s
}
// add out to map
signalMap[circuits[constraint.Op].Constraints[len(circuits[constraint.Op].Constraints)-1].Out+callsCountStr] = constraint.Out
for i := 1; i < len(circuits[constraint.Op].Constraints); i++ {
c := circuits[constraint.Op].Constraints[i]
// add constraint, puting unique names to vars
nc := &Constraint{
Op: c.Op,
V1: subsIfInMap(c.V1+callsCountStr, signalMap),
V2: subsIfInMap(c.V2+callsCountStr, signalMap),
Out: subsIfInMap(c.Out+callsCountStr, signalMap),
Literal: "",
}
nc.Literal = nc.Out + "=" + nc.V1 + nc.Op + nc.V2
circuits[currCircuit].Constraints = append(circuits[currCircuit].Constraints, *nc)
}
for _, s := range circuits[constraint.Op].Signals {
circuits[currCircuit].Signals = addToArrayIfNotExist(circuits[currCircuit].Signals, subsIfInMap(s+callsCountStr, signalMap))
}
callsCount++
continue
}
circuit.Constraints = append(circuit.Constraints, *constraint)
circuits[currCircuit].Constraints = append(circuits[currCircuit].Constraints, *constraint)
isVal, _ := isValue(constraint.V1)
if !isVal {
circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.V1)
circuits[currCircuit].Signals = addToArrayIfNotExist(circuits[currCircuit].Signals, constraint.V1)
}
isVal, _ = isValue(constraint.V2)
if !isVal {
circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.V2)
circuits[currCircuit].Signals = addToArrayIfNotExist(circuits[currCircuit].Signals, constraint.V2)
}
// if constraint.Out == "out" {
// if Out is "out", put it after first value (one) and before the inputs
// if constraint.Out == circuit.PublicInputs[0] {
// if existInArray(circuit.PublicInputs, constraint.Out) {
// // if Out is a public signal, put it after first value (one) and before the private inputs
// if !existInArray(circuit.Signals, constraint.Out) {
// // if already don't exists in signal array
// signalsCopy := copyArray(circuit.Signals)
// var auxSignals []string
// auxSignals = append(auxSignals, signalsCopy[0])
// auxSignals = append(auxSignals, constraint.Out)
// auxSignals = append(auxSignals, signalsCopy[1:]...)
// circuit.Signals = auxSignals
// // circuit.PublicInputs = append(circuit.PublicInputs, constraint.Out)
// circuit.NPublic++
// }
// } else {
circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.Out)
// }
circuits[currCircuit].Signals = addToArrayIfNotExist(circuits[currCircuit].Signals, constraint.Out)
}
circuits["main"].NVars = len(circuits["main"].Signals)
circuits["main"].NSignals = len(circuits["main"].Signals)
if mainExist == false {
return circuits["main"], errors.New("No 'main' func declared")
}
circuit.NVars = len(circuit.Signals)
circuit.NSignals = len(circuit.Signals)
return circuit, nil
return circuits["main"], nil
}
func copyArray(in []string) []string { // tmp
var out []string

BIN
go-snark-cli


+ 43
- 26
snark_test.go

@ -18,20 +18,37 @@ func TestZkFromFlatCircuitCode(t *testing.T) {
// circuit function
// y = x^3 + x + 5
flatCode := `
func test(private s0, public s1):
s2 = s0 * s0
s3 = s2 * s0
s4 = s3 + s0
s5 = s4 + 5
equals(s1, s5)
out = 1 * 1
code := `
func exp3(private a):
b = a * a
c = a * b
return c
func sum(private a, private b):
c = a + b
return c
func main(private s0, public s1):
s3 = exp3(s0)
s4 = sum(s3, s0)
s5 = s4 + 5
equals(s1, s5)
out = 1 * 1
`
fmt.Print("\nflat code of the circuit:")
fmt.Println(flatCode)
// the same code without the functions calling, all in one func
// code := `
// func test(private s0, public s1):
// s2 = s0 * s0
// s3 = s2 * s0
// s4 = s3 + s0
// s5 = s4 + 5
// equals(s1, s5)
// out = 1 * 1
// `
fmt.Print("\ncode of the circuit:")
fmt.Println(code)
// parse the code
parser := circuitcompiler.NewParser(strings.NewReader(flatCode))
parser := circuitcompiler.NewParser(strings.NewReader(code))
circuit, err := parser.Parse()
assert.Nil(t, err)
// fmt.Println("\ncircuit data:", circuit)
@ -47,8 +64,8 @@ func TestZkFromFlatCircuitCode(t *testing.T) {
w, err := circuit.CalculateWitness(privateInputs, publicSignals)
assert.Nil(t, err)
// flat code to R1CS
fmt.Println("\ngenerating R1CS from flat code")
// code to R1CS
fmt.Println("\ngenerating R1CS from code")
a, b, c := circuit.GenerateR1CS()
fmt.Println("\nR1CS:")
fmt.Println("a:", a)
@ -132,16 +149,16 @@ func TestZkFromFlatCircuitCode(t *testing.T) {
}
func TestZkMultiplication(t *testing.T) {
flatCode := `
func test(private a, private b, public c):
code := `
func main(private a, private b, public c):
d = a * b
equals(c, d)
out = 1 * 1
`
fmt.Println("flat code", flatCode)
fmt.Println("code", code)
// parse the code
parser := circuitcompiler.NewParser(strings.NewReader(flatCode))
parser := circuitcompiler.NewParser(strings.NewReader(code))
circuit, err := parser.Parse()
assert.Nil(t, err)
@ -155,8 +172,8 @@ func TestZkMultiplication(t *testing.T) {
w, err := circuit.CalculateWitness(privateInputs, publicSignals)
assert.Nil(t, err)
// flat code to R1CS
fmt.Println("\ngenerating R1CS from flat code")
// code to R1CS
fmt.Println("\ngenerating R1CS from code")
a, b, c := circuit.GenerateR1CS()
fmt.Println("\nR1CS:")
fmt.Println("a:", a)
@ -242,8 +259,8 @@ func TestZkMultiplication(t *testing.T) {
func TestMinimalFlow(t *testing.T) {
// circuit function
// y = x^3 + x + 5
flatCode := `
func test(private s0, public s1):
code := `
func main(private s0, public s1):
s2 = s0 * s0
s3 = s2 * s0
s4 = s3 + s0
@ -251,11 +268,11 @@ func TestMinimalFlow(t *testing.T) {
equals(s1, s5)
out = 1 * 1
`
fmt.Print("\nflat code of the circuit:")
fmt.Println(flatCode)
fmt.Print("\ncode of the circuit:")
fmt.Println(code)
// parse the code
parser := circuitcompiler.NewParser(strings.NewReader(flatCode))
parser := circuitcompiler.NewParser(strings.NewReader(code))
circuit, err := parser.Parse()
assert.Nil(t, err)
@ -268,8 +285,8 @@ func TestMinimalFlow(t *testing.T) {
w, err := circuit.CalculateWitness(privateInputs, publicSignals)
assert.Nil(t, err)
// flat code to R1CS
fmt.Println("\ngenerating R1CS from flat code")
// code to R1CS
fmt.Println("\ngenerating R1CS from code")
a, b, c := circuit.GenerateR1CS()
fmt.Println("\nR1CS:")
fmt.Println("a:", a)

Loading…
Cancel
Save