diff --git a/README.md b/README.md index d894590..86e573f 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ zkSNARK library implementation in Go - `Succinct Non-Interactive Zero Knowledge for a von Neumann Architecture`, Eli Ben-Sasson, Alessandro Chiesa, Eran Tromer, Madars Virza https://eprint.iacr.org/2013/879.pdf - `Pinocchio: Nearly practical verifiable computation`, Bryan Parno, Craig Gentry, Jon Howell, Mariana Raykova https://eprint.iacr.org/2013/279.pdf -## Caution +## Caution, Warning, etc Implementation of the zkSNARK [Pinocchio protocol](https://eprint.iacr.org/2013/279.pdf) from scratch in Go to understand the concepts. Do not use in production. Not finished, implementing this in my free time to understand it better, so I don't have much time. @@ -27,6 +27,8 @@ Current implementation status: - [x] verify proofs with BN128 pairing - [ ] fix 4th pairing proofs generation & verification: ê(Vkx+piA, piB) == ê(piH, Vkz) * ê(piC, G2) - [ ] move witness calculation outside the setup phase +- [ ] Groth16 +- [ ] multiple optimizations ## Usage diff --git a/circuitcompiler/circuit.go b/circuitcompiler/circuit.go index 9a0338c..220e989 100644 --- a/circuitcompiler/circuit.go +++ b/circuitcompiler/circuit.go @@ -2,6 +2,7 @@ package circuitcompiler import ( "errors" + "fmt" "math/big" "strconv" @@ -16,7 +17,6 @@ type Circuit struct { PrivateInputs []string PublicInputs []string Signals []string - PublicSignals []string Witness []*big.Int Constraints []Constraint R1CS struct { @@ -97,12 +97,12 @@ func (circ *Circuit) GenerateR1CS() ([][]*big.Int, [][]*big.Int, [][]*big.Int) { // if existInArray(constraint.Out) { if used[constraint.Out] { - panic(errors.New("out variable already used: " + constraint.Out)) + // panic(errors.New("out variable already used: " + constraint.Out)) + fmt.Println("variable already used") } used[constraint.Out] = true if constraint.Op == "in" { - // TODO constraint.PublicInputs - for i := 0; i < len(constraint.PrivateInputs); i++ { + for i := 0; i <= len(circ.PublicInputs); i++ { aConstraint[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Add(aConstraint[indexInArray(circ.Signals, constraint.Out)], big.NewInt(int64(1))) aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.Out, used) bConstraint[0] = big.NewInt(int64(1)) @@ -166,8 +166,14 @@ func (circ *Circuit) CalculateWitness(privateInputs []*big.Int, publicInputs []* } w := r1csqap.ArrayOfBigZeros(len(circ.Signals)) w[0] = big.NewInt(int64(1)) + for i, input := range publicInputs { + fmt.Println(i + 1) + fmt.Println(input) + w[i+1] = input + } for i, input := range privateInputs { - w[i+2] = input + fmt.Println(i + len(publicInputs) + 1) + w[i+len(publicInputs)+1] = input } for _, constraint := range circ.Constraints { if constraint.Op == "in" { diff --git a/circuitcompiler/circuit_test.go b/circuitcompiler/circuit_test.go index beb0625..8293d85 100644 --- a/circuitcompiler/circuit_test.go +++ b/circuitcompiler/circuit_test.go @@ -1,6 +1,7 @@ package circuitcompiler import ( + "encoding/json" "fmt" "math/big" "strings" @@ -21,17 +22,20 @@ func TestCircuitParser(t *testing.T) { m2 = m1 * s1 m3 = m2 + s1 out = m3 + 5 + */ // flat code, where er is expected_result + // equals(s5, s1) + // s1 = s5 * 1 flat := ` - func test(private x, public er): - aux = x*x - y = aux*x - z = x + y - res = z + 5 - equals(er, res) - out = 1 + func test(private s0, public s1): + s2 = s0*s0 + s3 = s2*s0 + s4 = s0 + s3 + s5 = s4 + 5 + s5 = s1 * one + out = 1 * 1 ` parser := NewParser(strings.NewReader(flat)) circuit, err := parser.Parse() @@ -84,4 +88,11 @@ func TestCircuitParser(t *testing.T) { w, err := circuit.CalculateWitness(privateInputs, publicInputs) assert.Nil(t, err) fmt.Println("w", w) + + circuitJson, _ := json.Marshal(circuit) + fmt.Println("circuit:", string(circuitJson)) + + assert.Equal(t, circuit.NPublic, 1) + assert.Equal(t, len(circuit.PublicInputs), 1) + assert.Equal(t, len(circuit.PrivateInputs), 1) } diff --git a/circuitcompiler/parser.go b/circuitcompiler/parser.go index 7533bca..51c9a20 100644 --- a/circuitcompiler/parser.go +++ b/circuitcompiler/parser.go @@ -105,10 +105,10 @@ func (p *Parser) parseLine() (*Constraint, error) { // TODO return c, nil } - if c.Literal == "out" { - // TODO - return c, nil - } + // if c.Literal == "out" { + // // TODO + // return c, nil + // } _, lit = p.scanIgnoreWhitespace() // skip = c.Literal += lit @@ -197,16 +197,26 @@ func (p *Parser) Parse() (*Circuit, error) { if !isVal { circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.V2) } - if constraint.Out == "out" { - // if Out is "out", put it after first value (one) and before the inputs + // fmt.Println("---") + // fmt.Println(circuit.PublicInputs[0]) + // fmt.Println(constraint.Out) + // fmt.Println(constraint.Out == circuit.PublicInputs[0]) + // fmt.Println("---") + + // if constraint.Out == "out" { + // if Out is "out", put it after first value (one) and before the inputs + // if constraint.Out == circuit.PublicInputs[0] { + if existInArray(circuit.PublicInputs, constraint.Out) { + // if Out is a public signal, put it after first value (one) and before the private inputs if !existInArray(circuit.Signals, constraint.Out) { + // if already don't exists in signal array signalsCopy := copyArray(circuit.Signals) var auxSignals []string auxSignals = append(auxSignals, signalsCopy[0]) auxSignals = append(auxSignals, constraint.Out) auxSignals = append(auxSignals, signalsCopy[1:]...) circuit.Signals = auxSignals - circuit.PublicSignals = append(circuit.PublicSignals, constraint.Out) + // circuit.PublicInputs = append(circuit.PublicInputs, constraint.Out) circuit.NPublic++ } } else { diff --git a/snark.go b/snark.go index 581fc6c..e61191a 100644 --- a/snark.go +++ b/snark.go @@ -1,7 +1,6 @@ package snark import ( - "bytes" "fmt" "math/big" "os" @@ -96,15 +95,15 @@ func GenerateTrustedSetup(witnessLength int, circuit circuitcompiler.Circuit, al var err error // input soundness - for i := 0; i < len(alphas); i++ { - for j := 0; j < len(alphas[i]); j++ { - if j <= circuit.NPublic { - if bytes.Equal(alphas[i][j].Bytes(), Utils.FqR.Zero().Bytes()) { - alphas[i][j] = Utils.FqR.One() - } - } - } - } + // for i := 0; i < len(alphas); i++ { + // for j := 0; j < len(alphas[i]); j++ { + // if j <= circuit.NPublic { + // if bytes.Equal(alphas[i][j].Bytes(), Utils.FqR.Zero().Bytes()) { + // alphas[i][j] = Utils.FqR.One() + // } + // } + // } + // } fmt.Println("alphas[1]", alphas[1]) @@ -217,7 +216,8 @@ func GenerateTrustedSetup(witnessLength int, circuit circuitcompiler.Circuit, al // z pol zpol := []*big.Int{big.NewInt(int64(1))} - for i := 1; i < len(circuit.Constraints); i++ { + // for i := 0; i < len(circuit.Constraints); i++ { + for i := 1; i < len(alphas)-1; i++ { zpol = Utils.PF.Mul( zpol, []*big.Int{ diff --git a/snark_test.go b/snark_test.go index 0fe0de6..d416bfe 100644 --- a/snark_test.go +++ b/snark_test.go @@ -1,6 +1,7 @@ package snark import ( + "bytes" "encoding/json" "fmt" "math/big" @@ -14,14 +15,19 @@ import ( ) func TestZkFromFlatCircuitCode(t *testing.T) { - // compile circuit and get the R1CS + + // circuit function + // y = x^3 + x + 5 flatCode := ` - func test(x): - aux = x*x - y = aux*x - z = x + y - out = z + 5 + func test(private s0, public s1): + s2 = s0 * s0 + s3 = s2 * s0 + s4 = s3 + s0 + s5 = s4 + 5 + s1 = s5 * 1 + s5 = s1 * 1 + out = 1 * 1 ` fmt.Print("\nflat code of the circuit:") fmt.Println(flatCode) @@ -36,10 +42,14 @@ func TestZkFromFlatCircuitCode(t *testing.T) { b3 := big.NewInt(int64(3)) privateInputs := []*big.Int{b3} + b35 := big.NewInt(int64(35)) + publicSignals := []*big.Int{b35} + // wittness - w, err := circuit.CalculateWitness(privateInputs) + w, err := circuit.CalculateWitness(privateInputs, publicSignals) assert.Nil(t, err) - fmt.Println("\nwitness", w) + fmt.Println("\n", circuit.Signals) + fmt.Println("witness", w) // flat code to R1CS fmt.Println("\ngenerating R1CS from flat code") @@ -58,6 +68,7 @@ func TestZkFromFlatCircuitCode(t *testing.T) { fmt.Println("betas", len(betas)) fmt.Println("gammas", len(gammas)) fmt.Println("zx length", len(zxQAP)) + assert.True(t, !bytes.Equal(alphas[1][1].Bytes(), big.NewInt(int64(0)).Bytes())) ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas) fmt.Println("ax length", len(ax)) @@ -65,9 +76,6 @@ func TestZkFromFlatCircuitCode(t *testing.T) { fmt.Println("cx length", len(cx)) fmt.Println("px length", len(px)) fmt.Println("px[last]", px[0]) - px0 := Utils.PF.F.Add(px[0], big.NewInt(int64(88))) - fmt.Println(px0) - assert.Equal(t, px0.Bytes(), Utils.PF.F.Zero().Bytes()) hxQAP := Utils.PF.DivisorPolynomial(px, zxQAP) fmt.Println("hx length", len(hxQAP)) @@ -83,7 +91,7 @@ func TestZkFromFlatCircuitCode(t *testing.T) { div, rem := Utils.PF.Div(px, zxQAP) assert.Equal(t, hxQAP, div) - assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4)) + assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(6)) // calculate trusted setup setup, err := GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas) @@ -97,6 +105,9 @@ func TestZkFromFlatCircuitCode(t *testing.T) { hx := Utils.PF.DivisorPolynomial(px, setup.Pk.Z) fmt.Println("hx pk.z", hx) // assert.Equal(t, hxQAP, hx) + div, rem = Utils.PF.Div(px, setup.Pk.Z) + assert.Equal(t, hx, div) + assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(6)) assert.Equal(t, px, Utils.PF.Mul(hxQAP, zxQAP)) // hx==px/zx so px==hx*zx @@ -117,13 +128,16 @@ func TestZkFromFlatCircuitCode(t *testing.T) { // fmt.Println("public signals:", proof.PublicSignals) fmt.Println("\nwitness", w) - // b1 := big.NewInt(int64(1)) - b35 := big.NewInt(int64(35)) - // publicSignals := []*big.Int{b1, b35} - publicSignals := []*big.Int{b35} + b35Verif := big.NewInt(int64(35)) + publicSignalsVerif := []*big.Int{b35Verif} before := time.Now() - assert.True(t, VerifyProof(*circuit, setup, proof, publicSignals, true)) + assert.True(t, VerifyProof(*circuit, setup, proof, publicSignalsVerif, true)) fmt.Println("verify proof time elapsed:", time.Since(before)) + + // check that with another public input the verification returns false + bOtherWrongPublic := big.NewInt(int64(34)) + wrongPublicSignalsVerif := []*big.Int{bOtherWrongPublic} + assert.True(t, !VerifyProof(*circuit, setup, proof, wrongPublicSignalsVerif, true)) } /*