2 Commits
0.0.1 ... 0.0.2

Author SHA1 Message Date
arnaucube
de5b60b826 add allow import circuits in circuits language compiler 2019-05-30 21:39:00 +02:00
arnaucube
165699b58f circuitcompiler allow to call declared functions in circuits language 2019-05-25 04:11:39 +02:00
10 changed files with 403 additions and 77 deletions

View File

@@ -31,9 +31,9 @@ Minimal complete flow implementation:
- [x] verify proofs with BN128 pairing - [x] verify proofs with BN128 pairing
Improvements from the minimal implementation: Improvements from the minimal implementation:
- [ ] allow `import` in circuits language - [x] allow to call functions in circuits language
- [x] allow `import` in circuits language
- [ ] allow `for` in circuits language - [ ] allow `for` in circuits language
- [ ] code to flat code (improve circuit compiler)
- [ ] move witness values calculation outside the setup phase - [ ] move witness values calculation outside the setup phase
- [ ] Groth16 - [ ] Groth16
- [ ] multiple optimizations - [ ] multiple optimizations
@@ -54,9 +54,13 @@ In this example we will follow the equation example from [Vitalik](https://mediu
#### Compile circuit #### Compile circuit
Having a circuit file `test.circuit`: Having a circuit file `test.circuit`:
``` ```
func test(private s0, public s1): func exp3(private a):
s2 = s0 * s0 b = a * a
s3 = s2 * s0 c = a * b
return c
func main(private s0, public s1):
s3 = exp3(s0)
s4 = s3 + s0 s4 = s3 + s0
s5 = s4 + 5 s5 = s4 + 5
equals(s1, s5) equals(s1, s5)
@@ -112,9 +116,13 @@ Example:
```go ```go
// compile circuit and get the R1CS // compile circuit and get the R1CS
flatCode := ` flatCode := `
func test(private s0, public s1): func exp3(private a):
s2 = s0 * s0 b = a * a
s3 = s2 * s0 c = a * b
return c
func main(private s0, public s1):
s3 = exp3(s0)
s4 = s3 + s0 s4 = s3 + s0
s5 = s4 + 5 s5 = s4 + 5
equals(s1, s5) equals(s1, s5)
@@ -147,8 +155,8 @@ a, b, c := circuit.GenerateR1CS()
/* /*
now we have the R1CS from the circuit: now we have the R1CS from the circuit:
a: [[0 0 1 0 0 0 0 0] [0 0 0 1 0 0 0 0] [0 0 1 0 1 0 0 0] [5 0 0 0 0 1 0 0] [0 0 0 0 0 0 1 0] [0 1 0 0 0 0 0 0] [1 0 0 0 0 0 0 0]] a: [[0 0 1 0 0 0 0 0] [0 0 1 0 0 0 0 0] [0 0 1 0 1 0 0 0] [5 0 0 0 0 1 0 0] [0 0 0 0 0 0 1 0] [0 1 0 0 0 0 0 0] [1 0 0 0 0 0 0 0]]
b: [[0 0 1 0 0 0 0 0] [0 0 1 0 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0]] b: [[0 0 1 0 0 0 0 0] [0 0 0 1 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0]]
c: [[0 0 0 1 0 0 0 0] [0 0 0 0 1 0 0 0] [0 0 0 0 0 1 0 0] [0 0 0 0 0 0 1 0] [0 1 0 0 0 0 0 0] [0 0 0 0 0 0 1 0] [0 0 0 0 0 0 0 1]] c: [[0 0 0 1 0 0 0 0] [0 0 0 0 1 0 0 0] [0 0 0 0 0 1 0 0] [0 0 0 0 0 0 1 0] [0 1 0 0 0 0 0 0] [0 0 0 0 0 0 1 0] [0 0 0 0 0 0 0 1]]
*/ */

View File

@@ -0,0 +1,8 @@
import "circuit-test-2.circuit"
func main(private s0, public s1):
s3 = exp3(s0)
s4 = sum(s3, s0)
s5 = s4 + 5
equals(s1, s5)
out = 1 * 1

View File

@@ -0,0 +1,7 @@
func exp3(private a):
b = a * a
c = a * b
return c
func sum(private a, private b):
c = a + b
return c

View File

@@ -1,7 +1,9 @@
package circuitcompiler package circuitcompiler
import ( import (
"bufio"
"math/big" "math/big"
"os"
"strings" "strings"
"testing" "testing"
@@ -11,7 +13,7 @@ import (
func TestCircuitParser(t *testing.T) { func TestCircuitParser(t *testing.T) {
// y = x^3 + x + 5 // y = x^3 + x + 5
flat := ` flat := `
func test(private s0, public s1): func main(private s0, public s1):
s2 = s0 * s0 s2 = s0 * s0
s3 = s2 * s0 s3 = s2 * s0
s4 = s3 + s0 s4 = s3 + s0
@@ -86,3 +88,161 @@ func TestCircuitParser(t *testing.T) {
assert.Equal(t, len(circuit.PublicInputs), 1) assert.Equal(t, len(circuit.PublicInputs), 1)
assert.Equal(t, len(circuit.PrivateInputs), 1) assert.Equal(t, len(circuit.PrivateInputs), 1)
} }
func TestCircuitWithFuncCallsParser(t *testing.T) {
// y = x^3 + x + 5
code := `
func exp3(private a):
b = a * a
c = a * b
return c
func sum(private a, private b):
c = a + b
return c
func main(private s0, public s1):
s3 = exp3(s0)
s4 = sum(s3, s0)
s5 = s4 + 5
equals(s1, s5)
out = 1 * 1
`
parser := NewParser(strings.NewReader(code))
circuit, err := parser.Parse()
assert.Nil(t, err)
// flat code to R1CS
a, b, c := circuit.GenerateR1CS()
assert.Equal(t, "s0", circuit.PrivateInputs[0])
assert.Equal(t, "s1", circuit.PublicInputs[0])
assert.Equal(t, []string{"one", "s1", "s0", "b0", "s3", "s4", "s5", "out"}, circuit.Signals)
// expected result
b0 := big.NewInt(int64(0))
b1 := big.NewInt(int64(1))
b5 := big.NewInt(int64(5))
aExpected := [][]*big.Int{
[]*big.Int{b0, b0, b1, b0, b0, b0, b0, b0},
[]*big.Int{b0, b0, b1, b0, b0, b0, b0, b0},
[]*big.Int{b0, b0, b1, b0, b1, b0, b0, b0},
[]*big.Int{b5, b0, b0, b0, b0, b1, b0, b0},
[]*big.Int{b0, b0, b0, b0, b0, b0, b1, b0},
[]*big.Int{b0, b1, b0, b0, b0, b0, b0, b0},
[]*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
}
bExpected := [][]*big.Int{
[]*big.Int{b0, b0, b1, b0, b0, b0, b0, b0},
[]*big.Int{b0, b0, b0, b1, b0, b0, b0, b0},
[]*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
[]*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
[]*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
[]*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
[]*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
}
cExpected := [][]*big.Int{
[]*big.Int{b0, b0, b0, b1, b0, b0, b0, b0},
[]*big.Int{b0, b0, b0, b0, b1, b0, b0, b0},
[]*big.Int{b0, b0, b0, b0, b0, b1, b0, b0},
[]*big.Int{b0, b0, b0, b0, b0, b0, b1, b0},
[]*big.Int{b0, b1, b0, b0, b0, b0, b0, b0},
[]*big.Int{b0, b0, b0, b0, b0, b0, b1, b0},
[]*big.Int{b0, b0, b0, b0, b0, b0, b0, b1},
}
assert.Equal(t, aExpected, a)
assert.Equal(t, bExpected, b)
assert.Equal(t, cExpected, c)
b3 := big.NewInt(int64(3))
privateInputs := []*big.Int{b3}
b35 := big.NewInt(int64(35))
publicInputs := []*big.Int{b35}
// Calculate Witness
w, err := circuit.CalculateWitness(privateInputs, publicInputs)
assert.Nil(t, err)
b9 := big.NewInt(int64(9))
b27 := big.NewInt(int64(27))
b30 := big.NewInt(int64(30))
wExpected := []*big.Int{b1, b35, b3, b9, b27, b30, b35, b1}
assert.Equal(t, wExpected, w)
// circuitJson, _ := json.Marshal(circuit)
// fmt.Println("circuit:", string(circuitJson))
assert.Equal(t, circuit.NPublic, 1)
assert.Equal(t, len(circuit.PublicInputs), 1)
assert.Equal(t, len(circuit.PrivateInputs), 1)
}
func TestCircuitFromFileWithImports(t *testing.T) {
circuitFile, err := os.Open("./circuit-test-1.circuit")
assert.Nil(t, err)
parser := NewParser(bufio.NewReader(circuitFile))
circuit, err := parser.Parse()
assert.Nil(t, err)
// flat code to R1CS
a, b, c := circuit.GenerateR1CS()
assert.Equal(t, "s0", circuit.PrivateInputs[0])
assert.Equal(t, "s1", circuit.PublicInputs[0])
assert.Equal(t, []string{"one", "s1", "s0", "b0", "s3", "s4", "s5", "out"}, circuit.Signals)
// expected result
b0 := big.NewInt(int64(0))
b1 := big.NewInt(int64(1))
b5 := big.NewInt(int64(5))
aExpected := [][]*big.Int{
[]*big.Int{b0, b0, b1, b0, b0, b0, b0, b0},
[]*big.Int{b0, b0, b1, b0, b0, b0, b0, b0},
[]*big.Int{b0, b0, b1, b0, b1, b0, b0, b0},
[]*big.Int{b5, b0, b0, b0, b0, b1, b0, b0},
[]*big.Int{b0, b0, b0, b0, b0, b0, b1, b0},
[]*big.Int{b0, b1, b0, b0, b0, b0, b0, b0},
[]*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
}
bExpected := [][]*big.Int{
[]*big.Int{b0, b0, b1, b0, b0, b0, b0, b0},
[]*big.Int{b0, b0, b0, b1, b0, b0, b0, b0},
[]*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
[]*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
[]*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
[]*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
[]*big.Int{b1, b0, b0, b0, b0, b0, b0, b0},
}
cExpected := [][]*big.Int{
[]*big.Int{b0, b0, b0, b1, b0, b0, b0, b0},
[]*big.Int{b0, b0, b0, b0, b1, b0, b0, b0},
[]*big.Int{b0, b0, b0, b0, b0, b1, b0, b0},
[]*big.Int{b0, b0, b0, b0, b0, b0, b1, b0},
[]*big.Int{b0, b1, b0, b0, b0, b0, b0, b0},
[]*big.Int{b0, b0, b0, b0, b0, b0, b1, b0},
[]*big.Int{b0, b0, b0, b0, b0, b0, b0, b1},
}
assert.Equal(t, aExpected, a)
assert.Equal(t, bExpected, b)
assert.Equal(t, cExpected, c)
b3 := big.NewInt(int64(3))
privateInputs := []*big.Int{b3}
b35 := big.NewInt(int64(35))
publicInputs := []*big.Int{b35}
// Calculate Witness
w, err := circuit.CalculateWitness(privateInputs, publicInputs)
assert.Nil(t, err)
b9 := big.NewInt(int64(9))
b27 := big.NewInt(int64(27))
b30 := big.NewInt(int64(30))
wExpected := []*big.Int{b1, b35, b3, b9, b27, b30, b35, b1}
assert.Equal(t, wExpected, w)
// circuitJson, _ := json.Marshal(circuit)
// fmt.Println("circuit:", string(circuitJson))
assert.Equal(t, circuit.NPublic, 1)
assert.Equal(t, len(circuit.PublicInputs), 1)
assert.Equal(t, len(circuit.PrivateInputs), 1)
}

View File

@@ -1,11 +1,13 @@
package circuitcompiler package circuitcompiler
import ( import (
"bufio"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"os" "os"
"regexp" "regexp"
"strconv"
"strings" "strings"
) )
@@ -68,6 +70,12 @@ func (p *Parser) parseLine() (*Constraint, error) {
if err != nil { if err != nil {
return c, err return c, err
} }
// get func name
fName := strings.Split(line, "(")[0]
fName = strings.Replace(fName, " ", "", -1)
fName = strings.Replace(fName, " ", "", -1)
c.V1 = fName // so, the name of the func will be in c.V1
// read string inside ( ) // read string inside ( )
rgx := regexp.MustCompile(`\((.*?)\)`) rgx := regexp.MustCompile(`\((.*?)\)`)
insideParenthesis := rgx.FindStringSubmatch(line) insideParenthesis := rgx.FindStringSubmatch(line)
@@ -105,20 +113,60 @@ func (p *Parser) parseLine() (*Constraint, error) {
c.V2 = params[1] c.V2 = params[1]
return c, nil return c, nil
} }
// if c.Literal == "out" { if c.Literal == "return" {
// // TODO _, varToReturn := p.scanIgnoreWhitespace()
// return c, nil c.Out = varToReturn
// } return c, nil
}
if c.Literal == "import" {
line, err := p.s.r.ReadString('\n')
if err != nil {
return c, err
}
// read string inside " "
path := strings.TrimLeft(strings.TrimRight(line, `"`), `"`)
path = strings.Replace(path, `"`, "", -1)
path = strings.Replace(path, " ", "", -1)
path = strings.Replace(path, "\n", "", -1)
c.Out = path
return c, nil
}
_, lit = p.scanIgnoreWhitespace() // skip = _, lit = p.scanIgnoreWhitespace() // skip =
c.Literal += lit c.Literal += lit
// v1 // v1
_, lit = p.scanIgnoreWhitespace() _, lit = p.scanIgnoreWhitespace()
// check if lit is a name of a func that we have declared
if _, ok := circuits[lit]; ok {
// if inside, is calling a declared function
c.Literal = "call"
c.Op = lit // c.Op handles the name of the function called
// put the inputs of the call into the c.PrivateInputs
// format: `funcname(a, b)`
line, err := p.s.r.ReadString(')')
if err != nil {
fmt.Println("ERR", err)
return c, err
}
// read string inside ( )
rgx := regexp.MustCompile(`\((.*?)\)`)
insideParenthesis := rgx.FindStringSubmatch(line)
varsString := strings.Replace(insideParenthesis[1], " ", "", -1)
params := strings.Split(varsString, ",")
c.PrivateInputs = params
return c, nil
}
c.V1 = lit c.V1 = lit
c.Literal += lit c.Literal += lit
// operator // operator
_, lit = p.scanIgnoreWhitespace() _, lit = p.scanIgnoreWhitespace()
if lit == "(" {
panic(errors.New("using not declared function"))
}
c.Op = lit c.Op = lit
c.Literal += lit c.Literal += lit
// v2 // v2
@@ -150,39 +198,67 @@ func addToArrayIfNotExist(arr []string, elem string) []string {
return arr return arr
} }
func subsIfInMap(original string, m map[string]string) string {
if v, ok := m[original]; ok {
return v
}
return original
}
var circuits map[string]*Circuit
// Parse parses the lines and returns the compiled Circuit // Parse parses the lines and returns the compiled Circuit
func (p *Parser) Parse() (*Circuit, error) { func (p *Parser) Parse() (*Circuit, error) {
circuit := &Circuit{} // funcsMap is a map holding the functions names and it's content as Circuit
circuit.Signals = append(circuit.Signals, "one") circuits = make(map[string]*Circuit)
mainExist := false
circuits["main"] = &Circuit{}
callsCount := 0
circuits["main"].Signals = append(circuits["main"].Signals, "one")
nInputs := 0 nInputs := 0
currCircuit := ""
for { for {
constraint, err := p.parseLine() constraint, err := p.parseLine()
if err != nil { if err != nil {
break break
} }
if constraint.Literal == "func" { if constraint.Literal == "func" {
// the name of the func is in constraint.V1
// check if the name of func is main
if constraint.V1 != "main" {
currCircuit = constraint.V1
circuits[currCircuit] = &Circuit{}
circuits[currCircuit].Constraints = append(circuits[currCircuit].Constraints, *constraint)
continue
}
currCircuit = "main"
mainExist = true
// l, _ := json.Marshal(constraint)
// fmt.Println(string(l))
// one constraint for each input // one constraint for each input
for _, in := range constraint.PublicInputs { for _, in := range constraint.PublicInputs {
newConstr := &Constraint{ newConstr := &Constraint{
Op: "in", Op: "in",
Out: in, Out: in,
} }
circuit.Constraints = append(circuit.Constraints, *newConstr) circuits[currCircuit].Constraints = append(circuits[currCircuit].Constraints, *newConstr)
nInputs++ nInputs++
circuit.Signals = addToArrayIfNotExist(circuit.Signals, in) circuits[currCircuit].Signals = addToArrayIfNotExist(circuits[currCircuit].Signals, in)
circuit.NPublic++ circuits[currCircuit].NPublic++
} }
for _, in := range constraint.PrivateInputs { for _, in := range constraint.PrivateInputs {
newConstr := &Constraint{ newConstr := &Constraint{
Op: "in", Op: "in",
Out: in, Out: in,
} }
circuit.Constraints = append(circuit.Constraints, *newConstr) circuits[currCircuit].Constraints = append(circuits[currCircuit].Constraints, *newConstr)
nInputs++ nInputs++
circuit.Signals = addToArrayIfNotExist(circuit.Signals, in) circuits[currCircuit].Signals = addToArrayIfNotExist(circuits[currCircuit].Signals, in)
} }
circuit.PublicInputs = constraint.PublicInputs circuits[currCircuit].PublicInputs = constraint.PublicInputs
circuit.PrivateInputs = constraint.PrivateInputs circuits[currCircuit].PrivateInputs = constraint.PrivateInputs
continue continue
} }
if constraint.Literal == "equals" { if constraint.Literal == "equals" {
@@ -193,7 +269,7 @@ func (p *Parser) Parse() (*Circuit, error) {
Out: constraint.V1, Out: constraint.V1,
Literal: "equals(" + constraint.V1 + ", " + constraint.V2 + "): " + constraint.V1 + "==" + constraint.V2 + " * 1", Literal: "equals(" + constraint.V1 + ", " + constraint.V2 + "): " + constraint.V1 + "==" + constraint.V2 + " * 1",
} }
circuit.Constraints = append(circuit.Constraints, *constr1) circuits[currCircuit].Constraints = append(circuits[currCircuit].Constraints, *constr1)
constr2 := &Constraint{ constr2 := &Constraint{
Op: "*", Op: "*",
V1: constraint.V1, V1: constraint.V1,
@@ -201,42 +277,73 @@ func (p *Parser) Parse() (*Circuit, error) {
Out: constraint.V2, Out: constraint.V2,
Literal: "equals(" + constraint.V1 + ", " + constraint.V2 + "): " + constraint.V2 + "==" + constraint.V1 + " * 1", Literal: "equals(" + constraint.V1 + ", " + constraint.V2 + "): " + constraint.V2 + "==" + constraint.V1 + " * 1",
} }
circuit.Constraints = append(circuit.Constraints, *constr2) circuits[currCircuit].Constraints = append(circuits[currCircuit].Constraints, *constr2)
continue continue
} }
circuit.Constraints = append(circuit.Constraints, *constraint) if constraint.Literal == "return" {
currCircuit = ""
continue
}
if constraint.Literal == "call" {
callsCountStr := strconv.Itoa(callsCount)
// for each of the constraints of the called circuit
// add it into the current circuit
signalMap := make(map[string]string)
for i, s := range constraint.PrivateInputs {
// signalMap[s] = circuits[constraint.Op].Constraints[0].PrivateInputs[i]
signalMap[circuits[constraint.Op].Constraints[0].PrivateInputs[i]+callsCountStr] = s
}
// add out to map
signalMap[circuits[constraint.Op].Constraints[len(circuits[constraint.Op].Constraints)-1].Out+callsCountStr] = constraint.Out
for i := 1; i < len(circuits[constraint.Op].Constraints); i++ {
c := circuits[constraint.Op].Constraints[i]
// add constraint, puting unique names to vars
nc := &Constraint{
Op: c.Op,
V1: subsIfInMap(c.V1+callsCountStr, signalMap),
V2: subsIfInMap(c.V2+callsCountStr, signalMap),
Out: subsIfInMap(c.Out+callsCountStr, signalMap),
Literal: "",
}
nc.Literal = nc.Out + "=" + nc.V1 + nc.Op + nc.V2
circuits[currCircuit].Constraints = append(circuits[currCircuit].Constraints, *nc)
}
for _, s := range circuits[constraint.Op].Signals {
circuits[currCircuit].Signals = addToArrayIfNotExist(circuits[currCircuit].Signals, subsIfInMap(s+callsCountStr, signalMap))
}
callsCount++
continue
}
if constraint.Literal == "import" {
circuitFile, err := os.Open(constraint.Out)
if err != nil {
panic(errors.New("imported path error: " + constraint.Out))
}
parser := NewParser(bufio.NewReader(circuitFile))
_, err = parser.Parse() // this will add the imported file funcs into the `circuits` map
continue
}
circuits[currCircuit].Constraints = append(circuits[currCircuit].Constraints, *constraint)
isVal, _ := isValue(constraint.V1) isVal, _ := isValue(constraint.V1)
if !isVal { if !isVal {
circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.V1) circuits[currCircuit].Signals = addToArrayIfNotExist(circuits[currCircuit].Signals, constraint.V1)
} }
isVal, _ = isValue(constraint.V2) isVal, _ = isValue(constraint.V2)
if !isVal { if !isVal {
circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.V2) circuits[currCircuit].Signals = addToArrayIfNotExist(circuits[currCircuit].Signals, constraint.V2)
} }
// if constraint.Out == "out" { circuits[currCircuit].Signals = addToArrayIfNotExist(circuits[currCircuit].Signals, constraint.Out)
// if Out is "out", put it after first value (one) and before the inputs
// if constraint.Out == circuit.PublicInputs[0] {
// if existInArray(circuit.PublicInputs, constraint.Out) {
// // if Out is a public signal, put it after first value (one) and before the private inputs
// if !existInArray(circuit.Signals, constraint.Out) {
// // if already don't exists in signal array
// signalsCopy := copyArray(circuit.Signals)
// var auxSignals []string
// auxSignals = append(auxSignals, signalsCopy[0])
// auxSignals = append(auxSignals, constraint.Out)
// auxSignals = append(auxSignals, signalsCopy[1:]...)
// circuit.Signals = auxSignals
// // circuit.PublicInputs = append(circuit.PublicInputs, constraint.Out)
// circuit.NPublic++
// }
// } else {
circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.Out)
// }
} }
circuit.NVars = len(circuit.Signals) circuits["main"].NVars = len(circuits["main"].Signals)
circuit.NSignals = len(circuit.Signals) circuits["main"].NSignals = len(circuits["main"].Signals)
return circuit, nil if mainExist == false {
return circuits["main"], errors.New("No 'main' func declared")
}
return circuits["main"], nil
} }
func copyArray(in []string) []string { // tmp func copyArray(in []string) []string { // tmp
var out []string var out []string

View File

@@ -0,0 +1,8 @@
import "imported-example.circuit"
func main(private s0, public s1):
s3 = exp3(s0)
s4 = sum(s3, s0)
s5 = s4 + 5
equals(s1, s5)
out = 1 * 1

View File

@@ -0,0 +1,7 @@
func exp3(private a):
b = a * a
c = a * b
return c
func sum(private a, private b):
c = a + b
return c

Binary file not shown.

View File

@@ -18,20 +18,37 @@ func TestZkFromFlatCircuitCode(t *testing.T) {
// circuit function // circuit function
// y = x^3 + x + 5 // y = x^3 + x + 5
flatCode := ` code := `
func test(private s0, public s1): func exp3(private a):
s2 = s0 * s0 b = a * a
s3 = s2 * s0 c = a * b
s4 = s3 + s0 return c
s5 = s4 + 5 func sum(private a, private b):
equals(s1, s5) c = a + b
out = 1 * 1 return c
func main(private s0, public s1):
s3 = exp3(s0)
s4 = sum(s3, s0)
s5 = s4 + 5
equals(s1, s5)
out = 1 * 1
` `
fmt.Print("\nflat code of the circuit:") // the same code without the functions calling, all in one func
fmt.Println(flatCode) // code := `
// func test(private s0, public s1):
// s2 = s0 * s0
// s3 = s2 * s0
// s4 = s3 + s0
// s5 = s4 + 5
// equals(s1, s5)
// out = 1 * 1
// `
fmt.Print("\ncode of the circuit:")
fmt.Println(code)
// parse the code // parse the code
parser := circuitcompiler.NewParser(strings.NewReader(flatCode)) parser := circuitcompiler.NewParser(strings.NewReader(code))
circuit, err := parser.Parse() circuit, err := parser.Parse()
assert.Nil(t, err) assert.Nil(t, err)
// fmt.Println("\ncircuit data:", circuit) // fmt.Println("\ncircuit data:", circuit)
@@ -47,8 +64,8 @@ func TestZkFromFlatCircuitCode(t *testing.T) {
w, err := circuit.CalculateWitness(privateInputs, publicSignals) w, err := circuit.CalculateWitness(privateInputs, publicSignals)
assert.Nil(t, err) assert.Nil(t, err)
// flat code to R1CS // code to R1CS
fmt.Println("\ngenerating R1CS from flat code") fmt.Println("\ngenerating R1CS from code")
a, b, c := circuit.GenerateR1CS() a, b, c := circuit.GenerateR1CS()
fmt.Println("\nR1CS:") fmt.Println("\nR1CS:")
fmt.Println("a:", a) fmt.Println("a:", a)
@@ -132,16 +149,16 @@ func TestZkFromFlatCircuitCode(t *testing.T) {
} }
func TestZkMultiplication(t *testing.T) { func TestZkMultiplication(t *testing.T) {
flatCode := ` code := `
func test(private a, private b, public c): func main(private a, private b, public c):
d = a * b d = a * b
equals(c, d) equals(c, d)
out = 1 * 1 out = 1 * 1
` `
fmt.Println("flat code", flatCode) fmt.Println("code", code)
// parse the code // parse the code
parser := circuitcompiler.NewParser(strings.NewReader(flatCode)) parser := circuitcompiler.NewParser(strings.NewReader(code))
circuit, err := parser.Parse() circuit, err := parser.Parse()
assert.Nil(t, err) assert.Nil(t, err)
@@ -155,8 +172,8 @@ func TestZkMultiplication(t *testing.T) {
w, err := circuit.CalculateWitness(privateInputs, publicSignals) w, err := circuit.CalculateWitness(privateInputs, publicSignals)
assert.Nil(t, err) assert.Nil(t, err)
// flat code to R1CS // code to R1CS
fmt.Println("\ngenerating R1CS from flat code") fmt.Println("\ngenerating R1CS from code")
a, b, c := circuit.GenerateR1CS() a, b, c := circuit.GenerateR1CS()
fmt.Println("\nR1CS:") fmt.Println("\nR1CS:")
fmt.Println("a:", a) fmt.Println("a:", a)
@@ -242,8 +259,8 @@ func TestZkMultiplication(t *testing.T) {
func TestMinimalFlow(t *testing.T) { func TestMinimalFlow(t *testing.T) {
// circuit function // circuit function
// y = x^3 + x + 5 // y = x^3 + x + 5
flatCode := ` code := `
func test(private s0, public s1): func main(private s0, public s1):
s2 = s0 * s0 s2 = s0 * s0
s3 = s2 * s0 s3 = s2 * s0
s4 = s3 + s0 s4 = s3 + s0
@@ -251,11 +268,11 @@ func TestMinimalFlow(t *testing.T) {
equals(s1, s5) equals(s1, s5)
out = 1 * 1 out = 1 * 1
` `
fmt.Print("\nflat code of the circuit:") fmt.Print("\ncode of the circuit:")
fmt.Println(flatCode) fmt.Println(code)
// parse the code // parse the code
parser := circuitcompiler.NewParser(strings.NewReader(flatCode)) parser := circuitcompiler.NewParser(strings.NewReader(code))
circuit, err := parser.Parse() circuit, err := parser.Parse()
assert.Nil(t, err) assert.Nil(t, err)
@@ -268,8 +285,8 @@ func TestMinimalFlow(t *testing.T) {
w, err := circuit.CalculateWitness(privateInputs, publicSignals) w, err := circuit.CalculateWitness(privateInputs, publicSignals)
assert.Nil(t, err) assert.Nil(t, err)
// flat code to R1CS // code to R1CS
fmt.Println("\ngenerating R1CS from flat code") fmt.Println("\ngenerating R1CS from code")
a, b, c := circuit.GenerateR1CS() a, b, c := circuit.GenerateR1CS()
fmt.Println("\nR1CS:") fmt.Println("\nR1CS:")
fmt.Println("a:", a) fmt.Println("a:", a)

View File

@@ -24,12 +24,14 @@ syn keyword goSnarkCircuitPrivatePublic private public
syn keyword goSnarkCircuitOut out syn keyword goSnarkCircuitOut out
syn keyword goSnarkCircuitEquals equals syn keyword goSnarkCircuitEquals equals
syn keyword goSnarkCircuitFunction func syn keyword goSnarkCircuitFunction func
syn keyword goSnarkCircuitImport import
syn match goSnarkCircuitFuncCall /\<\K\k*\ze\s*(/ syn match goSnarkCircuitFuncCall /\<\K\k*\ze\s*(/
syn keyword goSnarkCircuitPrivate private nextgroup=goSnarkCircuitInputName skipwhite syn keyword goSnarkCircuitPrivate private nextgroup=goSnarkCircuitInputName skipwhite
syn keyword goSnarkCircuitPublic public nextgroup=goSnarkCircuitInputName skipwhite syn keyword goSnarkCircuitPublic public nextgroup=goSnarkCircuitInputName skipwhite
syn match goSnarkCircuitInputName '\i\+' contained syn match goSnarkCircuitInputName '\i\+' contained
syn match goSnarkCircuitBraces "[{}\[\]]" syn match goSnarkCircuitBraces "[{}\[\]]"
syn match goSnarkCircuitParens "[()]" syn match goSnarkCircuitParens "[()]"
syn region goSnarkCircuitPath start=+"+ skip=+\\\\\|\\"+ end=+"\|$+
syn sync fromstart syn sync fromstart
syn sync maxlines=100 syn sync maxlines=100
@@ -44,12 +46,14 @@ hi def link goSnarkCircuitOpSymbols Operator
hi def link goSnarkCircuitFuncCall Function hi def link goSnarkCircuitFuncCall Function
hi def link goSnarkCircuitEquals Identifier hi def link goSnarkCircuitEquals Identifier
hi def link goSnarkCircuitFunction Keyword hi def link goSnarkCircuitFunction Keyword
hi def link goSnarkCircuitImport Keyword
hi def link goSnarkCircuitBraces Function hi def link goSnarkCircuitBraces Function
hi def link goSnarkCircuitPrivate Keyword hi def link goSnarkCircuitPrivate Keyword
hi def link goSnarkCircuitPublic Keyword hi def link goSnarkCircuitPublic Keyword
hi def link goSnarkCircuitInputName Special hi def link goSnarkCircuitInputName Special
hi def link goSnarkCircuitOut Special hi def link goSnarkCircuitOut Special
hi def link goSnarkCircuitPrivatePublic Keyword hi def link goSnarkCircuitPrivatePublic Keyword
hi def link goSnarkCircuitPath String
let b:current_syntax = "go-snark-circuit" let b:current_syntax = "go-snark-circuit"
if main_syntax == 'go-snark-circuit' if main_syntax == 'go-snark-circuit'