From 0806af6b803d7325aa8c4896324b89dbab39634e Mon Sep 17 00:00:00 2001 From: arnaucube Date: Wed, 26 Dec 2018 16:40:05 +0100 Subject: [PATCH] flat circuit code to R1CS working --- README.md | 34 ++++++++---- circuitcompiler/circuit.go | 94 ++++++++++++++++++++++++++++++--- circuitcompiler/circuit_test.go | 37 ++++++++++++- circuitcompiler/parser.go | 68 ++++++++++++++++++------ snark_test.go | 70 +++++++++++++++++++++++- 5 files changed, 267 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index 5f2a46e..52f2415 100644 --- a/README.md +++ b/README.md @@ -30,24 +30,36 @@ fqR := fields.NewFq(bn.R) // new Polynomial Field pf := r1csqap.NewPolynomialField(f) -/* -suppose that we have the following variables with *big.Int elements: -a = [[0 1 0 0 0 0] [0 0 0 1 0 0] [0 1 0 0 1 0] [5 0 0 0 0 1]] -b = [[0 1 0 0 0 0] [0 1 0 0 0 0] [1 0 0 0 0 0] [1 0 0 0 0 0]] -c = [[0 0 0 1 0 0] [0 0 0 0 1 0] [0 0 0 0 0 1] [0 0 1 0 0 0]] +// compile circuit and get the R1CS +flatCode := ` +func test(x): + aux = x*x + y = aux*x + z = x + y + out = z + 5 +` +// parse the code +parser := circuitcompiler.NewParser(strings.NewReader(flatCode)) +circuit, err := parser.Parse() +assert.Nil(t, err) +fmt.Println(circuit) +// flat code to R1CS +fmt.Println("generating R1CS from flat code") +a, b, c := circuit.GenerateR1CS() -w = [1, 3, 35, 9, 27, 30] +/* +now we have the R1CS from the circuit: +a == [[0 1 0 0 0 0] [0 0 0 1 0 0] [0 1 0 0 1 0] [5 0 0 0 0 1]] +b == [[0 1 0 0 0 0] [0 1 0 0 0 0] [1 0 0 0 0 0] [1 0 0 0 0 0]] +c == [[0 0 0 1 0 0] [0 0 0 0 1 0] [0 0 0 0 0 1] [0 0 1 0 0 0]] */ + alphas, betas, gammas, zx := pf.R1CSToQAP(a, b, c) // wittness = 1, 3, 35, 9, 27, 30 w := []*big.Int{b1, b3, b35, b9, b27, b30} -circuit := compiler.Circuit{ - NVars: 6, - NPublic: 0, - NSignals: len(w), -} + ax, bx, cx, px := pf.CombinePolynomials(w, alphas, betas, gammas) hx := pf.DivisorPolinomial(px, zx) diff --git a/circuitcompiler/circuit.go b/circuitcompiler/circuit.go index 3f2e14a..9143551 100644 --- a/circuitcompiler/circuit.go +++ b/circuitcompiler/circuit.go @@ -1,8 +1,11 @@ package circuitcompiler import ( - "fmt" + "errors" "math/big" + "strconv" + + "github.com/arnaucube/go-snark/r1csqap" ) type Circuit struct { @@ -19,14 +22,89 @@ type Circuit struct { C [][]*big.Int } } +type Constraint struct { + // v1 op v2 = out + Op string + V1 string + V2 string + Out string + Literal string + + Inputs []string // in func delcaration case +} + +func indexInArray(arr []string, e string) int { + for i, a := range arr { + if a == e { + return i + } + } + return -1 +} +func isValue(a string) (bool, int) { + v, err := strconv.Atoi(a) + if err != nil { + return false, 0 + } + return true, v +} +func insertVar(arr []*big.Int, signals []string, v string, used map[string]bool) ([]*big.Int, map[string]bool) { + isVal, value := isValue(v) + valueBigInt := big.NewInt(int64(value)) + if isVal { + arr[0] = new(big.Int).Add(arr[0], valueBigInt) + } else { + if !used[v] { + panic(errors.New("using variable before it's set")) + } + arr[indexInArray(signals, v)] = new(big.Int).Add(arr[indexInArray(signals, v)], big.NewInt(int64(1))) + } + return arr, used +} + +func (circ *Circuit) GenerateR1CS() ([][]*big.Int, [][]*big.Int, [][]*big.Int) { + // from flat code to R1CS + + var a [][]*big.Int + var b [][]*big.Int + var c [][]*big.Int + + used := make(map[string]bool) + for _, constraint := range circ.Constraints { + + aConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals)) + bConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals)) + cConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals)) + + // if existInArray(constraint.Out) { + if used[constraint.Out] { + panic(errors.New("out variable already used: " + constraint.Out)) + } + used[constraint.Out] = true + if constraint.Op == "in" { + for i := 0; i < len(constraint.Inputs); i++ { + aConstraint[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Add(aConstraint[indexInArray(circ.Signals, constraint.Out)], big.NewInt(int64(1))) + aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.Out, used) + bConstraint[0] = big.NewInt(int64(1)) + + } + continue + + } else if constraint.Op == "+" { + cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1)) + aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V1, used) + aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V2, used) + bConstraint[0] = big.NewInt(int64(1)) + } else if constraint.Op == "*" { + cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1)) + aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V1, used) + bConstraint, used = insertVar(bConstraint, circ.Signals, constraint.V2, used) + } -func (c *Circuit) GenerateR1CS() { - fmt.Print("function with inputs: ") - fmt.Println(c.Inputs) - fmt.Print("signals: ") - fmt.Println(c.Signals) - for _, constraint := range c.Constraints { - fmt.Println(constraint.Literal) + a = append(a, aConstraint) + b = append(b, bConstraint) + c = append(c, cConstraint) } + return a, b, c } diff --git a/circuitcompiler/circuit_test.go b/circuitcompiler/circuit_test.go index b61d8a8..9cb183a 100644 --- a/circuitcompiler/circuit_test.go +++ b/circuitcompiler/circuit_test.go @@ -2,6 +2,7 @@ package circuitcompiler import ( "fmt" + "math/big" "strings" "testing" @@ -37,6 +38,40 @@ func TestCircuitParser(t *testing.T) { // flat code to R1CS fmt.Println("generating R1CS from flat code") - circuit.GenerateR1CS() + a, b, c := circuit.GenerateR1CS() + fmt.Print("function with inputs: ") fmt.Println(circuit.Inputs) + + fmt.Print("signals: ") + fmt.Println(circuit.Signals) + + // expected result + b0 := big.NewInt(int64(0)) + b1 := big.NewInt(int64(1)) + b5 := big.NewInt(int64(5)) + aExpected := [][]*big.Int{ + []*big.Int{b0, b1, b0, b0, b0, b0}, + []*big.Int{b0, b0, b0, b1, b0, b0}, + []*big.Int{b0, b1, b0, b0, b1, b0}, + []*big.Int{b5, b0, b0, b0, b0, b1}, + } + bExpected := [][]*big.Int{ + []*big.Int{b0, b1, b0, b0, b0, b0}, + []*big.Int{b0, b1, b0, b0, b0, b0}, + []*big.Int{b1, b0, b0, b0, b0, b0}, + []*big.Int{b1, b0, b0, b0, b0, b0}, + } + cExpected := [][]*big.Int{ + []*big.Int{b0, b0, b0, b1, b0, b0}, + []*big.Int{b0, b0, b0, b0, b1, b0}, + []*big.Int{b0, b0, b0, b0, b0, b1}, + []*big.Int{b0, b0, b1, b0, b0, b0}, + } + + assert.Equal(t, aExpected, a) + assert.Equal(t, bExpected, b) + assert.Equal(t, cExpected, c) + fmt.Println(a) + fmt.Println(b) + fmt.Println(c) } diff --git a/circuitcompiler/parser.go b/circuitcompiler/parser.go index 972fd7b..cc94700 100644 --- a/circuitcompiler/parser.go +++ b/circuitcompiler/parser.go @@ -16,17 +16,6 @@ type Parser struct { } } -type Constraint struct { - // v1 op v2 = out - Op Token - V1 string - V2 string - Out string - Literal string - - Inputs []string // in func delcaration case -} - func NewParser(r io.Reader) *Parser { return &Parser{s: NewScanner(r)} } @@ -90,7 +79,8 @@ func (p *Parser) ParseLine() (*Constraint, error) { c.V1 = lit c.Literal += lit // operator - c.Op, lit = p.scanIgnoreWhitespace() + _, lit = p.scanIgnoreWhitespace() + c.Op = lit c.Literal += lit // v2 _, lit = p.scanIgnoreWhitespace() @@ -102,6 +92,15 @@ func (p *Parser) ParseLine() (*Constraint, error) { return c, nil } +func existInArray(arr []string, elem string) bool { + for _, v := range arr { + if v == elem { + return true + } + } + return false +} + func addToArrayIfNotExist(arr []string, elem string) []string { for _, v := range arr { if v == elem { @@ -111,22 +110,61 @@ func addToArrayIfNotExist(arr []string, elem string) []string { arr = append(arr, elem) return arr } + func (p *Parser) Parse() (*Circuit, error) { circuit := &Circuit{} circuit.Signals = append(circuit.Signals, "one") + nInputs := 0 for { constraint, err := p.ParseLine() if err != nil { break } if constraint.Literal == "func" { + // one constraint for each input + for _, in := range constraint.Inputs { + newConstr := &Constraint{ + Op: "in", + Out: in, + } + circuit.Constraints = append(circuit.Constraints, *newConstr) + nInputs++ + } circuit.Inputs = constraint.Inputs continue } circuit.Constraints = append(circuit.Constraints, *constraint) - circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.V1) - circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.V2) - circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.Out) + isVal, _ := isValue(constraint.V1) + if !isVal { + circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.V1) + } + isVal, _ = isValue(constraint.V2) + if !isVal { + circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.V2) + } + if constraint.Out == "out" { + // if Out is "out", put it after the inputs + if !existInArray(circuit.Signals, constraint.Out) { + signalsCopy := copyArray(circuit.Signals) + var auxSignals []string + auxSignals = append(auxSignals, signalsCopy[0:nInputs+1]...) + auxSignals = append(auxSignals, constraint.Out) + auxSignals = append(auxSignals, signalsCopy[nInputs+1:]...) + circuit.Signals = auxSignals + } + } else { + circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.Out) + } } + circuit.NVars = len(circuit.Signals) + circuit.NSignals = len(circuit.Signals) + circuit.NPublic = 0 return circuit, nil } +func copyArray(in []string) []string { // tmp + var out []string + for _, e := range in { + out = append(out, e) + } + return out +} diff --git a/snark_test.go b/snark_test.go index 29121cc..8a166c9 100644 --- a/snark_test.go +++ b/snark_test.go @@ -3,6 +3,7 @@ package snark import ( "fmt" "math/big" + "strings" "testing" "github.com/arnaucube/go-snark/bn128" @@ -12,7 +13,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestZk(t *testing.T) { +func TestZkFromHardcodedR1CS(t *testing.T) { bn, err := bn128.NewBn128() assert.Nil(t, err) @@ -85,3 +86,70 @@ func TestZk(t *testing.T) { assert.True(t, VerifyProof(bn, circuit, setup, proof)) } + +func TestZkFromFlatCircuitCode(t *testing.T) { + bn, err := bn128.NewBn128() + assert.Nil(t, err) + + // new Finite Field + fqR := fields.NewFq(bn.R) + + // new Polynomial Field + pf := r1csqap.NewPolynomialField(fqR) + + // compile circuit and get the R1CS + flatCode := ` + func test(x): + aux = x*x + y = aux*x + z = x + y + out = z + 5 + ` + // parse the code + parser := circuitcompiler.NewParser(strings.NewReader(flatCode)) + circuit, err := parser.Parse() + assert.Nil(t, err) + fmt.Println(circuit) + // flat code to R1CS + fmt.Println("generating R1CS from flat code") + a, b, c := circuit.GenerateR1CS() + + alphas, betas, gammas, zx := pf.R1CSToQAP(a, b, c) + + // wittness = 1, 3, 35, 9, 27, 30 + b1 := big.NewInt(int64(1)) + b3 := big.NewInt(int64(3)) + b9 := big.NewInt(int64(9)) + b27 := big.NewInt(int64(27)) + b30 := big.NewInt(int64(30)) + b35 := big.NewInt(int64(35)) + w := []*big.Int{b1, b3, b35, b9, b27, b30} + + ax, bx, cx, px := pf.CombinePolynomials(w, alphas, betas, gammas) + + hx := pf.DivisorPolinomial(px, zx) + + // hx==px/zx so px==hx*zx + assert.Equal(t, px, pf.Mul(hx, zx)) + + // p(x) = a(x) * b(x) - c(x) == h(x) * z(x) + abc := pf.Sub(pf.Mul(ax, bx), cx) + assert.Equal(t, abc, px) + hz := pf.Mul(hx, zx) + assert.Equal(t, abc, hz) + + div, rem := pf.Div(px, zx) + assert.Equal(t, hx, div) + assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4)) + + // calculate trusted setup + setup, err := GenerateTrustedSetup(bn, fqR, pf, len(w), *circuit, alphas, betas, gammas, zx) + assert.Nil(t, err) + fmt.Println("t", setup.Toxic.T) + + // piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t) + proof, err := GenerateProofs(bn, fqR, *circuit, setup, hx, w) + assert.Nil(t, err) + + assert.True(t, VerifyProof(bn, *circuit, setup, proof)) +}