diff --git a/groth16/groth16.go b/groth16/groth16.go index 3b2e396..03518f1 100644 --- a/groth16/groth16.go +++ b/groth16/groth16.go @@ -3,6 +3,7 @@ package groth16 import ( + "fmt" "math/big" "github.com/arnaucube/go-snark/bn128" @@ -53,8 +54,8 @@ type Setup struct { } } -// ProofGroth contains the parameters to proof the zkSNARK -type ProofGroth struct { +// Proof contains the parameters to proof the zkSNARK +type Proof struct { PiA [3]*big.Int PiB [3][2]*big.Int PiC [3]*big.Int @@ -216,3 +217,86 @@ func GenerateTrustedSetup(witnessLength int, circuit circuitcompiler.Circuit, al return setup, nil } + +// GenerateProofs generates all the parameters to proof the zkSNARK from the Circuit, Setup and the Witness +func GenerateProofs(circuit circuitcompiler.Circuit, setup Setup, w []*big.Int, px []*big.Int) (Proof, error) { + var proof Proof + proof.PiA = [3]*big.Int{Utils.Bn.G1.F.Zero(), Utils.Bn.G1.F.Zero(), Utils.Bn.G1.F.Zero()} + proof.PiB = Utils.Bn.Fq6.Zero() + proof.PiC = [3]*big.Int{Utils.Bn.G1.F.Zero(), Utils.Bn.G1.F.Zero(), Utils.Bn.G1.F.Zero()} + + r, err := Utils.FqR.Rand() + if err != nil { + return Proof{}, err + } + s, err := Utils.FqR.Rand() + if err != nil { + return Proof{}, err + } + + // piBG1 will hold all the same than proof.PiB but in G1 curve + piBG1 := [3]*big.Int{Utils.Bn.G1.F.Zero(), Utils.Bn.G1.F.Zero(), Utils.Bn.G1.F.Zero()} + + for i := 0; i < circuit.NVars; i++ { + proof.PiA = Utils.Bn.G1.Add(proof.PiA, Utils.Bn.G1.MulScalar(setup.Pk.G1.At[i], w[i])) + piBG1 = Utils.Bn.G1.Add(piBG1, Utils.Bn.G1.MulScalar(setup.Pk.G1.BACGamma[i], w[i])) + proof.PiB = Utils.Bn.G2.Add(proof.PiB, Utils.Bn.G2.MulScalar(setup.Pk.G2.BACGamma[i], w[i])) + } + for i := circuit.NPublic + 1; i < circuit.NVars; i++ { + proof.PiC = Utils.Bn.G1.Add(proof.PiC, Utils.Bn.G1.MulScalar(setup.Pk.BACDelta[i], w[i])) + } + + // piA = (Σ from 0 to m (pk.A * w[i])) + pk.Alpha1 + r * δ + proof.PiA = Utils.Bn.G1.Add(proof.PiA, setup.Pk.G1.Alpha) + deltaR := Utils.Bn.G1.MulScalar(setup.Pk.G1.Delta, r) + proof.PiA = Utils.Bn.G1.Add(proof.PiA, deltaR) + + // piBG1 = (Σ from 0 to m (pk.B1 * w[i])) + pk.g1.Beta + s * δ + // piB = piB2 = (Σ from 0 to m (pk.B2 * w[i])) + pk.g2.Beta + s * δ + piBG1 = Utils.Bn.G1.Add(piBG1, setup.Pk.G1.Beta) + proof.PiB = Utils.Bn.G2.Add(proof.PiB, setup.Pk.G2.Beta) + deltaSG1 := Utils.Bn.G1.MulScalar(setup.Pk.G1.Delta, s) + piBG1 = Utils.Bn.G1.Add(piBG1, deltaSG1) + deltaSG2 := Utils.Bn.G2.MulScalar(setup.Pk.G2.Delta, s) + proof.PiB = Utils.Bn.G2.Add(proof.PiB, deltaSG2) + + hx := Utils.PF.DivisorPolynomial(px, setup.Pk.Z) // maybe move this calculation to a previous step + + // piC = (Σ from l+1 to m (w[i] * (pk.g1.Beta + pk.g1.Alpha + pk.C)) + h(tau)) / δ) + piA*s + r*piB - r*s*δ + for i := 0; i < len(hx); i++ { + proof.PiC = Utils.Bn.G1.Add(proof.PiC, Utils.Bn.G1.MulScalar(setup.Pk.PowersTauDelta[i], hx[i])) + } + proof.PiC = Utils.Bn.G1.Add(proof.PiC, Utils.Bn.G1.MulScalar(proof.PiA, s)) + proof.PiC = Utils.Bn.G1.Add(proof.PiC, Utils.Bn.G1.MulScalar(piBG1, r)) + negRS := Utils.FqR.Neg(Utils.FqR.Mul(r, s)) + proof.PiC = Utils.Bn.G1.Add(proof.PiC, Utils.Bn.G1.MulScalar(setup.Pk.G1.Delta, negRS)) + + return proof, nil +} + +// VerifyProof verifies over the BN128 the Pairings of the Proof +func VerifyProof(circuit circuitcompiler.Circuit, setup Setup, proof Proof, publicSignals []*big.Int, debug bool) bool { + + icPubl := setup.Vk.IC[0] + for i := 0; i < len(publicSignals); i++ { + icPubl = Utils.Bn.G1.Add(icPubl, Utils.Bn.G1.MulScalar(setup.Vk.IC[i+1], publicSignals[i])) + } + + if !Utils.Bn.Fq12.Equal( + Utils.Bn.Pairing(proof.PiA, proof.PiB), + Utils.Bn.Fq12.Mul( + Utils.Bn.Pairing(setup.Vk.G1.Alpha, setup.Vk.G2.Beta), + Utils.Bn.Fq12.Mul( + Utils.Bn.Pairing(icPubl, setup.Vk.G2.Gamma), + Utils.Bn.Pairing(proof.PiC, setup.Vk.G2.Delta)))) { + if debug { + fmt.Println("❌ groth16 verification not passed") + } + return false + } + if debug { + fmt.Println("✓ groth16 verification passed") + } + + return true +} diff --git a/groth16/groth16_test.go b/groth16/groth16_test.go new file mode 100644 index 0000000..fc91226 --- /dev/null +++ b/groth16/groth16_test.go @@ -0,0 +1,107 @@ +package groth16 + +import ( + "bytes" + "fmt" + "math/big" + "strings" + "testing" + "time" + + "github.com/arnaucube/go-snark/circuitcompiler" + "github.com/arnaucube/go-snark/r1csqap" + "github.com/stretchr/testify/assert" +) + +func TestGroth16MinimalFlow(t *testing.T) { + fmt.Println("testing Groth16 minimal flow") + // circuit function + // y = x^3 + x + 5 + code := ` + func main(private s0, public s1): + s2 = s0 * s0 + s3 = s2 * s0 + s4 = s3 + s0 + s5 = s4 + 5 + equals(s1, s5) + out = 1 * 1 + ` + fmt.Print("\ncode of the circuit:") + + // parse the code + parser := circuitcompiler.NewParser(strings.NewReader(code)) + circuit, err := parser.Parse() + assert.Nil(t, err) + + b3 := big.NewInt(int64(3)) + privateInputs := []*big.Int{b3} + b35 := big.NewInt(int64(35)) + publicSignals := []*big.Int{b35} + + // wittness + w, err := circuit.CalculateWitness(privateInputs, publicSignals) + assert.Nil(t, err) + + // code to R1CS + fmt.Println("\ngenerating R1CS from code") + a, b, c := circuit.GenerateR1CS() + fmt.Println("\nR1CS:") + fmt.Println("a:", a) + fmt.Println("b:", b) + fmt.Println("c:", c) + + // R1CS to QAP + // TODO zxQAP is not used and is an old impl, TODO remove + alphas, betas, gammas, _ := Utils.PF.R1CSToQAP(a, b, c) + fmt.Println("qap") + assert.Equal(t, 8, len(alphas)) + assert.Equal(t, 8, len(alphas)) + assert.Equal(t, 8, len(alphas)) + assert.True(t, !bytes.Equal(alphas[1][1].Bytes(), big.NewInt(int64(0)).Bytes())) + + ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas) + assert.Equal(t, 7, len(ax)) + assert.Equal(t, 7, len(bx)) + assert.Equal(t, 7, len(cx)) + assert.Equal(t, 13, len(px)) + + // --- + // from here is the GROTH16 + // --- + // calculate trusted setup + fmt.Println("groth") + setup, err := GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas) + assert.Nil(t, err) + fmt.Println("\nt:", setup.Toxic.T) + + hx := Utils.PF.DivisorPolynomial(px, setup.Pk.Z) + div, rem := Utils.PF.Div(px, setup.Pk.Z) + assert.Equal(t, hx, div) + assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(6)) + + // hx==px/zx so px==hx*zx + assert.Equal(t, px, Utils.PF.Mul(hx, setup.Pk.Z)) + + // check length of polynomials H(x) and Z(x) + assert.Equal(t, len(hx), len(px)-len(setup.Pk.Z)+1) + + proof, err := GenerateProofs(*circuit, setup, w, px) + assert.Nil(t, err) + + // fmt.Println("\n proofs:") + // fmt.Println(proof) + + // fmt.Println("public signals:", proof.PublicSignals) + fmt.Println("\nsignals:", circuit.Signals) + fmt.Println("witness:", w) + b35Verif := big.NewInt(int64(35)) + publicSignalsVerif := []*big.Int{b35Verif} + before := time.Now() + assert.True(t, VerifyProof(*circuit, setup, proof, publicSignalsVerif, true)) + fmt.Println("verify proof time elapsed:", time.Since(before)) + + // check that with another public input the verification returns false + bOtherWrongPublic := big.NewInt(int64(34)) + wrongPublicSignalsVerif := []*big.Int{bOtherWrongPublic} + assert.True(t, !VerifyProof(*circuit, setup, proof, wrongPublicSignalsVerif, false)) +} diff --git a/snark.go b/snark.go index 642fc54..9bfe5cb 100644 --- a/snark.go +++ b/snark.go @@ -1,3 +1,5 @@ +// implementation of https://eprint.iacr.org/2013/879.pdf + package snark import ( @@ -289,7 +291,9 @@ func VerifyProof(circuit circuitcompiler.Circuit, setup Setup, proof Proof, publ pairingPiaVa := Utils.Bn.Pairing(proof.PiA, setup.Vk.Vka) pairingPiapG2 := Utils.Bn.Pairing(proof.PiAp, Utils.Bn.G2.G) if !Utils.Bn.Fq12.Equal(pairingPiaVa, pairingPiapG2) { - fmt.Println("❌ e(piA, Va) == e(piA', g2), valid knowledge commitment for A") + if debug { + fmt.Println("❌ e(piA, Va) == e(piA', g2), valid knowledge commitment for A") + } return false } if debug { @@ -300,7 +304,9 @@ func VerifyProof(circuit circuitcompiler.Circuit, setup Setup, proof Proof, publ pairingVbPib := Utils.Bn.Pairing(setup.Vk.Vkb, proof.PiB) pairingPibpG2 := Utils.Bn.Pairing(proof.PiBp, Utils.Bn.G2.G) if !Utils.Bn.Fq12.Equal(pairingVbPib, pairingPibpG2) { - fmt.Println("❌ e(Vb, piB) == e(piB', g2), valid knowledge commitment for B") + if debug { + fmt.Println("❌ e(Vb, piB) == e(piB', g2), valid knowledge commitment for B") + } return false } if debug { @@ -311,7 +317,9 @@ func VerifyProof(circuit circuitcompiler.Circuit, setup Setup, proof Proof, publ pairingPicVc := Utils.Bn.Pairing(proof.PiC, setup.Vk.Vkc) pairingPicpG2 := Utils.Bn.Pairing(proof.PiCp, Utils.Bn.G2.G) if !Utils.Bn.Fq12.Equal(pairingPicVc, pairingPicpG2) { - fmt.Println("❌ e(piC, Vc) == e(piC', g2), valid knowledge commitment for C") + if debug { + fmt.Println("❌ e(piC, Vc) == e(piC', g2), valid knowledge commitment for C") + } return false } if debug { @@ -330,7 +338,9 @@ func VerifyProof(circuit circuitcompiler.Circuit, setup Setup, proof Proof, publ Utils.Bn.Fq12.Mul( Utils.Bn.Pairing(proof.PiH, setup.Vk.Vkz), Utils.Bn.Pairing(proof.PiC, Utils.Bn.G2.G))) { - fmt.Println("❌ e(Vkx+piA, piB) == e(piH, Vkz) * e(piC, g2), QAP disibility checked") + if debug { + fmt.Println("❌ e(Vkx+piA, piB) == e(piH, Vkz) * e(piC, g2), QAP disibility checked") + } return false } if debug { diff --git a/snark_test.go b/snark_test.go index 9fdfeae..4cb4a85 100644 --- a/snark_test.go +++ b/snark_test.go @@ -9,10 +9,104 @@ import ( "time" "github.com/arnaucube/go-snark/circuitcompiler" + "github.com/arnaucube/go-snark/groth16" "github.com/arnaucube/go-snark/r1csqap" "github.com/stretchr/testify/assert" ) +func TestGroth16MinimalFlow(t *testing.T) { + fmt.Println("testing Groth16 minimal flow") + // circuit function + // y = x^3 + x + 5 + code := ` + func main(private s0, public s1): + s2 = s0 * s0 + s3 = s2 * s0 + s4 = s3 + s0 + s5 = s4 + 5 + equals(s1, s5) + out = 1 * 1 + ` + fmt.Print("\ncode of the circuit:") + + // parse the code + parser := circuitcompiler.NewParser(strings.NewReader(code)) + circuit, err := parser.Parse() + assert.Nil(t, err) + + b3 := big.NewInt(int64(3)) + privateInputs := []*big.Int{b3} + b35 := big.NewInt(int64(35)) + publicSignals := []*big.Int{b35} + + // wittness + w, err := circuit.CalculateWitness(privateInputs, publicSignals) + assert.Nil(t, err) + + // code to R1CS + fmt.Println("\ngenerating R1CS from code") + a, b, c := circuit.GenerateR1CS() + fmt.Println("\nR1CS:") + fmt.Println("a:", a) + fmt.Println("b:", b) + fmt.Println("c:", c) + + // R1CS to QAP + // TODO zxQAP is not used and is an old impl, TODO remove + alphas, betas, gammas, _ := Utils.PF.R1CSToQAP(a, b, c) + fmt.Println("qap") + assert.Equal(t, 8, len(alphas)) + assert.Equal(t, 8, len(alphas)) + assert.Equal(t, 8, len(alphas)) + assert.True(t, !bytes.Equal(alphas[1][1].Bytes(), big.NewInt(int64(0)).Bytes())) + + ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas) + assert.Equal(t, 7, len(ax)) + assert.Equal(t, 7, len(bx)) + assert.Equal(t, 7, len(cx)) + assert.Equal(t, 13, len(px)) + + // --- + // from here is the GROTH16 + // --- + // calculate trusted setup + fmt.Println("groth") + setup, err := groth16.GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas) + assert.Nil(t, err) + fmt.Println("\nt:", setup.Toxic.T) + + hx := Utils.PF.DivisorPolynomial(px, setup.Pk.Z) + div, rem := Utils.PF.Div(px, setup.Pk.Z) + assert.Equal(t, hx, div) + assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(6)) + + // hx==px/zx so px==hx*zx + assert.Equal(t, px, Utils.PF.Mul(hx, setup.Pk.Z)) + + // check length of polynomials H(x) and Z(x) + assert.Equal(t, len(hx), len(px)-len(setup.Pk.Z)+1) + + proof, err := groth16.GenerateProofs(*circuit, setup, w, px) + assert.Nil(t, err) + + // fmt.Println("\n proofs:") + // fmt.Println(proof) + + // fmt.Println("public signals:", proof.PublicSignals) + fmt.Println("\nsignals:", circuit.Signals) + fmt.Println("witness:", w) + b35Verif := big.NewInt(int64(35)) + publicSignalsVerif := []*big.Int{b35Verif} + before := time.Now() + assert.True(t, groth16.VerifyProof(*circuit, setup, proof, publicSignalsVerif, true)) + fmt.Println("verify proof time elapsed:", time.Since(before)) + + // check that with another public input the verification returns false + bOtherWrongPublic := big.NewInt(int64(34)) + wrongPublicSignalsVerif := []*big.Int{bOtherWrongPublic} + assert.True(t, !groth16.VerifyProof(*circuit, setup, proof, wrongPublicSignalsVerif, false)) +} + func TestZkFromFlatCircuitCode(t *testing.T) { // compile circuit and get the R1CS @@ -145,7 +239,7 @@ func TestZkFromFlatCircuitCode(t *testing.T) { // check that with another public input the verification returns false bOtherWrongPublic := big.NewInt(int64(34)) wrongPublicSignalsVerif := []*big.Int{bOtherWrongPublic} - assert.True(t, !VerifyProof(*circuit, setup, proof, wrongPublicSignalsVerif, true)) + assert.True(t, !VerifyProof(*circuit, setup, proof, wrongPublicSignalsVerif, false)) } func TestZkMultiplication(t *testing.T) { @@ -253,7 +347,7 @@ func TestZkMultiplication(t *testing.T) { // check that with another public input the verification returns false bOtherWrongPublic := big.NewInt(int64(11)) wrongPublicSignalsVerif := []*big.Int{bOtherWrongPublic} - assert.True(t, !VerifyProof(*circuit, setup, proof, wrongPublicSignalsVerif, true)) + assert.True(t, !VerifyProof(*circuit, setup, proof, wrongPublicSignalsVerif, false)) } func TestMinimalFlow(t *testing.T) { @@ -342,5 +436,5 @@ func TestMinimalFlow(t *testing.T) { // check that with another public input the verification returns false bOtherWrongPublic := big.NewInt(int64(34)) wrongPublicSignalsVerif := []*big.Int{bOtherWrongPublic} - assert.True(t, !VerifyProof(*circuit, setup, proof, wrongPublicSignalsVerif, true)) + assert.True(t, !VerifyProof(*circuit, setup, proof, wrongPublicSignalsVerif, false)) }