diff --git a/.gitignore b/.gitignore index 4f61f39..26876d4 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ cli/inputs.json cli/proofs.json cli/test.circuit cli/trustedsetup.json +tmp diff --git a/README.md b/README.md index 31fe644..2db9d3d 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ Current implementation status: - [x] G1 and G2 curve operations - [x] BN128 Pairing - [x] circuit code compiler - - [ ] code to flat code + - [ ] code to flat code (improve circuit compiler) - [x] flat code compiler - [x] circuit to R1CS - [x] polynomial operations @@ -24,6 +24,8 @@ Current implementation status: - [x] generate trusted setup - [x] generate proofs - [x] verify proofs with BN128 pairing + - [ ] fix 4th pairing proofs generation & verification +- [ ] WASM implementation to run on browsers ## Usage diff --git a/circuitcompiler/circuit.go b/circuitcompiler/circuit.go index 8c6d9f8..9cb0070 100644 --- a/circuitcompiler/circuit.go +++ b/circuitcompiler/circuit.go @@ -131,6 +131,9 @@ func (circ *Circuit) GenerateR1CS() ([][]*big.Int, [][]*big.Int, [][]*big.Int) { c = append(c, cConstraint) } + circ.R1CS.A = a + circ.R1CS.B = b + circ.R1CS.C = c return a, b, c } @@ -144,6 +147,11 @@ func grabVar(signals []string, w []*big.Int, vStr string) *big.Int { } } +type Inputs struct { + Private []*big.Int + Publics []*big.Int +} + // CalculateWitness calculates the Witness of a Circuit based on the given inputs func (circ *Circuit) CalculateWitness(inputs []*big.Int) ([]*big.Int, error) { if len(inputs) != len(circ.Inputs) { diff --git a/cli/go-snark-cli b/cli/go-snark-cli deleted file mode 100755 index 7798b74..0000000 Binary files a/cli/go-snark-cli and /dev/null differ diff --git a/cli/main.go b/cli/main.go index 6768620..0557efc 100644 --- a/cli/main.go +++ b/cli/main.go @@ -30,6 +30,12 @@ var commands = []cli.Command{ Usage: "compile a circuit", Action: CompileCircuit, }, + { + Name: "trustedsetup", + Aliases: []string{}, + Usage: "generate trusted setup for a circuit", + Action: TrustedSetup, + }, { Name: "genproofs", Aliases: []string{}, @@ -79,12 +85,13 @@ func CompileCircuit(context *cli.Context) error { panicErr(err) // parse inputs from inputsFile - var inputs []*big.Int + // var inputs []*big.Int + var inputs circuitcompiler.Inputs json.Unmarshal([]byte(string(inputsFile)), &inputs) panicErr(err) // calculate wittness - w, err := circuit.CalculateWitness(inputs) + w, err := circuit.CalculateWitness(inputs.Private) panicErr(err) fmt.Println("\nwitness", w) @@ -137,11 +144,6 @@ func CompileCircuit(context *cli.Context) error { } } - // calculate trusted setup - setup, err := snark.GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas, zx) - panicErr(err) - fmt.Println("\nt:", setup.Toxic.T) - // store circuit to json jsonData, err := json.Marshal(circuit) panicErr(err) @@ -153,6 +155,41 @@ func CompileCircuit(context *cli.Context) error { jsonFile.Close() fmt.Println("Compiled Circuit data written to ", jsonFile.Name()) + return nil +} + +func TrustedSetup(context *cli.Context) error { + // open compiledcircuit.json + compiledcircuitFile, err := ioutil.ReadFile("compiledcircuit.json") + panicErr(err) + var circuit circuitcompiler.Circuit + json.Unmarshal([]byte(string(compiledcircuitFile)), &circuit) + panicErr(err) + + // read inputs file + inputsFile, err := ioutil.ReadFile("inputs.json") + panicErr(err) + // parse inputs from inputsFile + // var inputs []*big.Int + var inputs circuitcompiler.Inputs + json.Unmarshal([]byte(string(inputsFile)), &inputs) + panicErr(err) + // calculate wittness + w, err := circuit.CalculateWitness(inputs.Private) + panicErr(err) + + // R1CS to QAP + alphas, betas, gammas, zx := snark.Utils.PF.R1CSToQAP(circuit.R1CS.A, circuit.R1CS.B, circuit.R1CS.C) + fmt.Println("qap") + fmt.Println(alphas) + fmt.Println(betas) + fmt.Println(gammas) + + // calculate trusted setup + setup, err := snark.GenerateTrustedSetup(len(w), circuit, alphas, betas, gammas, zx) + panicErr(err) + fmt.Println("\nt:", setup.Toxic.T) + // remove setup.Toxic var tsetup snark.Setup tsetup.Pk = setup.Pk @@ -161,10 +198,10 @@ func CompileCircuit(context *cli.Context) error { tsetup.G2T = setup.G2T // store setup to json - jsonData, err = json.Marshal(tsetup) + jsonData, err := json.Marshal(tsetup) panicErr(err) // store setup into file - jsonFile, err = os.Create("trustedsetup.json") + jsonFile, err := os.Create("trustedsetup.json") panicErr(err) defer jsonFile.Close() jsonFile.Write(jsonData) @@ -192,27 +229,34 @@ func GenerateProofs(context *cli.Context) error { inputsFile, err := ioutil.ReadFile("inputs.json") panicErr(err) // parse inputs from inputsFile - var inputs []*big.Int + // var inputs []*big.Int + var inputs circuitcompiler.Inputs json.Unmarshal([]byte(string(inputsFile)), &inputs) panicErr(err) // calculate wittness - w, err := circuit.CalculateWitness(inputs) + w, err := circuit.CalculateWitness(inputs.Private) panicErr(err) fmt.Println("\nwitness", w) // flat code to R1CS - a, b, c := circuit.GenerateR1CS() + // a, b, c := circuit.GenerateR1CS() + a := circuit.R1CS.A + b := circuit.R1CS.B + c := circuit.R1CS.C // R1CS to QAP alphas, betas, gammas, zx := snark.Utils.PF.R1CSToQAP(a, b, c) _, _, _, px := snark.Utils.PF.CombinePolynomials(w, alphas, betas, gammas) hx := snark.Utils.PF.DivisorPolynomial(px, zx) + fmt.Println(circuit) + fmt.Println(trustedsetup.G1T) + fmt.Println(hx) + fmt.Println(w) proof, err := snark.GenerateProofs(circuit, trustedsetup, hx, w) panicErr(err) fmt.Println("\n proofs:") fmt.Println(proof) - fmt.Println("public signals:", proof.PublicSignals) // store proofs to json jsonData, err := json.Marshal(proof) @@ -249,7 +293,8 @@ func VerifyProofs(context *cli.Context) error { json.Unmarshal([]byte(string(trustedsetupFile)), &trustedsetup) panicErr(err) - verified := snark.VerifyProof(circuit, trustedsetup, proof, true) + // TODO read publicSignals from file + verified := snark.VerifyProof(circuit, trustedsetup, proof, publicSignals, true) if !verified { fmt.Println("ERROR: proofs not verified") } else { diff --git a/snark.go b/snark.go index a4449e3..fe7dbe6 100644 --- a/snark.go +++ b/snark.go @@ -1,6 +1,7 @@ package snark import ( + "bytes" "fmt" "math/big" "os" @@ -41,7 +42,7 @@ type Setup struct { Vka [3][2]*big.Int Vkb [3]*big.Int Vkc [3][2]*big.Int - A [][3]*big.Int + IC [][3]*big.Int G1Kbg [3]*big.Int // g1 * Kbeta * Kgamma G2Kbg [3][2]*big.Int // g2 * Kbeta * Kgamma G2Kg [3][2]*big.Int // g2 * Kgamma @@ -51,15 +52,15 @@ type Setup struct { // Proof contains the parameters to proof the zkSNARK type Proof struct { - PiA [3]*big.Int - PiAp [3]*big.Int - PiB [3][2]*big.Int - PiBp [3]*big.Int - PiC [3]*big.Int - PiCp [3]*big.Int - PiH [3]*big.Int - PiKp [3]*big.Int - PublicSignals []*big.Int + PiA [3]*big.Int + PiAp [3]*big.Int + PiB [3][2]*big.Int + PiBp [3]*big.Int + PiC [3]*big.Int + PiCp [3]*big.Int + PiH [3]*big.Int + PiKp [3]*big.Int + // PublicSignals []*big.Int } type utils struct { @@ -92,6 +93,18 @@ func prepareUtils() utils { func GenerateTrustedSetup(witnessLength int, circuit circuitcompiler.Circuit, alphas, betas, gammas [][]*big.Int, zx []*big.Int) (Setup, error) { var setup Setup var err error + + // input soundness + for i := 0; i < len(alphas); i++ { + for j := 0; j < len(alphas[i]); j++ { + if j <= circuit.NPublic { + if bytes.Equal(alphas[i][j].Bytes(), Utils.Bn.Fq1.Zero().Bytes()) { + alphas[i][j] = Utils.Bn.Fq1.One() + } + } + } + } + // generate random t value setup.Toxic.T, err = Utils.FqR.Rand() if err != nil { @@ -136,13 +149,19 @@ func GenerateTrustedSetup(witnessLength int, circuit circuitcompiler.Circuit, al // encrypt t values with curve generators var gt1 [][3]*big.Int var gt2 [][3][2]*big.Int - for i := 0; i < witnessLength; i++ { - tPow := Utils.FqR.Exp(setup.Toxic.T, big.NewInt(int64(i))) - tEncr1 := Utils.Bn.G1.MulScalar(Utils.Bn.G1.G, tPow) - gt1 = append(gt1, tEncr1) - tEncr2 := Utils.Bn.G2.MulScalar(Utils.Bn.G2.G, tPow) - gt2 = append(gt2, tEncr2) + gt1 = append(gt1, Utils.Bn.G1.G) + tEncr := setup.Toxic.T + for i := 1; i < witnessLength; i++ { + gt1 = append(gt1, Utils.Bn.G1.MulScalar(Utils.Bn.G1.G, tEncr)) + tEncr = Utils.Bn.Fq1.Mul(tEncr, setup.Toxic.T) } + // for i := 0; i < witnessLength; i++ { + // tPow := Utils.FqR.Exp(setup.Toxic.T, big.NewInt(int64(i))) + // tEncr1 := Utils.Bn.G1.MulScalar(Utils.Bn.G1.G, tPow) + // gt1 = append(gt1, tEncr1) + // tEncr2 := Utils.Bn.G2.MulScalar(Utils.Bn.G2.G, tPow) + // gt2 = append(gt2, tEncr2) + // } // gt1: g1, g1*t, g1*t^2, g1*t^3, ... // gt2: g2, g2*t, g2*t^2, ... setup.G1T = gt1 @@ -163,25 +182,27 @@ func GenerateTrustedSetup(witnessLength int, circuit circuitcompiler.Circuit, al setup.Vk.G2Kbg = Utils.Bn.G2.MulScalar(Utils.Bn.G2.G, kbg) setup.Vk.G2Kg = Utils.Bn.G2.MulScalar(Utils.Bn.G2.G, setup.Toxic.Kgamma) - // for i := 0; i < circuit.NSignals; i++ { for i := 0; i < circuit.NVars; i++ { at := Utils.PF.Eval(alphas[i], setup.Toxic.T) - a := Utils.Bn.G1.MulScalar(Utils.Bn.G1.G, at) + rhoAat := Utils.Bn.Fq1.Mul(setup.Toxic.RhoA, at) + a := Utils.Bn.G1.MulScalar(Utils.Bn.G1.G, rhoAat) setup.Pk.A = append(setup.Pk.A, a) if i <= circuit.NPublic { - setup.Vk.A = append(setup.Vk.A, a) + setup.Vk.IC = append(setup.Vk.IC, a) } bt := Utils.PF.Eval(betas[i], setup.Toxic.T) - bg1 := Utils.Bn.G1.MulScalar(Utils.Bn.G1.G, bt) - bg2 := Utils.Bn.G2.MulScalar(Utils.Bn.G2.G, bt) + rhoBbt := Utils.Bn.Fq1.Mul(setup.Toxic.RhoB, bt) + bg1 := Utils.Bn.G1.MulScalar(Utils.Bn.G1.G, rhoBbt) + bg2 := Utils.Bn.G2.MulScalar(Utils.Bn.G2.G, rhoBbt) setup.Pk.B = append(setup.Pk.B, bg2) ct := Utils.PF.Eval(gammas[i], setup.Toxic.T) - c := Utils.Bn.G1.MulScalar(Utils.Bn.G1.G, ct) + rhoCct := Utils.Bn.Fq1.Mul(setup.Toxic.RhoC, ct) + c := Utils.Bn.G1.MulScalar(Utils.Bn.G1.G, rhoCct) setup.Pk.C = append(setup.Pk.C, c) - kt := Utils.FqR.Add(Utils.FqR.Add(at, bt), ct) + kt := Utils.FqR.Add(Utils.FqR.Add(rhoAat, rhoBbt), rhoCct) k := Utils.Bn.G1.Affine(Utils.Bn.G1.MulScalar(Utils.Bn.G1.G, kt)) ktest := Utils.Bn.G1.Affine(Utils.Bn.G1.Add(Utils.Bn.G1.Add(a, bg1), c)) @@ -196,7 +217,9 @@ func GenerateTrustedSetup(witnessLength int, circuit circuitcompiler.Circuit, al k_ := Utils.Bn.G1.MulScalar(Utils.Bn.G1.G, kt) setup.Pk.Kp = append(setup.Pk.Kp, Utils.Bn.G1.MulScalar(k_, setup.Toxic.Kbeta)) } - setup.Vk.Vkz = Utils.Bn.G2.MulScalar(Utils.Bn.G2.G, Utils.PF.Eval(zx, setup.Toxic.T)) + zt := Utils.PF.Eval(zx, setup.Toxic.T) + rhoCzt := Utils.Bn.Fq1.Mul(setup.Toxic.RhoC, zt) + setup.Vk.Vkz = Utils.Bn.G2.MulScalar(Utils.Bn.G2.G, rhoCzt) return setup, nil } @@ -231,13 +254,12 @@ func GenerateProofs(circuit circuitcompiler.Circuit, setup Setup, hx []*big.Int, for i := 0; i < len(hx); i++ { proof.PiH = Utils.Bn.G1.Add(proof.PiH, Utils.Bn.G1.MulScalar(setup.G1T[i], hx[i])) } - proof.PublicSignals = w[1:2] // out signal return proof, nil } // VerifyProof verifies over the BN128 the Pairings of the Proof -func VerifyProof(circuit circuitcompiler.Circuit, setup Setup, proof Proof, printVer bool) bool { +func VerifyProof(circuit circuitcompiler.Circuit, setup Setup, proof Proof, publicSignals []*big.Int, printVer bool) bool { // e(piA, Va) == e(piA', g2) pairingPiaVa := Utils.Bn.Pairing(proof.PiA, setup.Vk.Vka) pairingPiapG2 := Utils.Bn.Pairing(proof.PiAp, Utils.Bn.G2.G) @@ -269,14 +291,14 @@ func VerifyProof(circuit circuitcompiler.Circuit, setup Setup, proof Proof, prin } // Vkx, to then calculate Vkx+piA - vkxpia := setup.Vk.A[0] - for i := 0; i < len(proof.PublicSignals); i++ { - vkxpia = Utils.Bn.G1.Add(vkxpia, Utils.Bn.G1.MulScalar(setup.Vk.A[i+1], proof.PublicSignals[i])) + vkxpia := setup.Vk.IC[0] + for i := 0; i < len(publicSignals); i++ { + vkxpia = Utils.Bn.G1.Add(vkxpia, Utils.Bn.G1.MulScalar(setup.Vk.IC[i+1], publicSignals[i])) } // e(Vkx+piA, piB) == e(piH, Vkz) * e(piC, g2) if !Utils.Bn.Fq12.Equal( - Utils.Bn.Pairing(Utils.Bn.G1.Add(vkxpia, proof.PiA), proof.PiB), + Utils.Bn.Pairing(Utils.Bn.G1.Add(vkxpia, proof.PiA), proof.PiB), // TODO Add(vkxpia, proof.PiA) can go outside in order to save computation, as is reused later Utils.Bn.Fq12.Mul( Utils.Bn.Pairing(proof.PiH, setup.Vk.Vkz), Utils.Bn.Pairing(proof.PiC, Utils.Bn.G2.G))) { diff --git a/snark_test.go b/snark_test.go index ef53444..3f9ef7e 100644 --- a/snark_test.go +++ b/snark_test.go @@ -1,6 +1,7 @@ package snark import ( + "encoding/json" "fmt" "math/big" "strings" @@ -12,6 +13,76 @@ import ( "github.com/stretchr/testify/assert" ) +/* +func TestZkMultiplication(t *testing.T) { + + // compile circuit and get the R1CS + flatCode := ` + func test(a, b): + out = a * b + ` + + // parse the code + parser := circuitcompiler.NewParser(strings.NewReader(flatCode)) + circuit, err := parser.Parse() + assert.Nil(t, err) + + b3 := big.NewInt(int64(3)) + b4 := big.NewInt(int64(4)) + inputs := []*big.Int{b3, b4} + // wittness + w, err := circuit.CalculateWitness(inputs) + assert.Nil(t, err) + + fmt.Println("circuit") + fmt.Println(circuit.NPublic) + + // flat code to R1CS + a, b, c := circuit.GenerateR1CS() + fmt.Println("\nR1CS:") + fmt.Println("a:", a) + fmt.Println("b:", b) + fmt.Println("c:", c) + + // R1CS to QAP + alphas, betas, gammas, zx := Utils.PF.R1CSToQAP(a, b, c) + fmt.Println("qap") + fmt.Println("alphas", alphas) + fmt.Println("betas", betas) + fmt.Println("gammas", gammas) + + ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas) + + hx := Utils.PF.DivisorPolynomial(px, zx) + + // hx==px/zx so px==hx*zx + assert.Equal(t, px, Utils.PF.Mul(hx, zx)) + + // p(x) = a(x) * b(x) - c(x) == h(x) * z(x) + abc := Utils.PF.Sub(Utils.PF.Mul(ax, bx), cx) + assert.Equal(t, abc, px) + hz := Utils.PF.Mul(hx, zx) + assert.Equal(t, abc, hz) + + div, rem := Utils.PF.Div(px, zx) + assert.Equal(t, hx, div) + assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(1)) + + // calculate trusted setup + setup, err := GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas, zx) + assert.Nil(t, err) + + // piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t) + proof, err := GenerateProofs(*circuit, setup, hx, w) + assert.Nil(t, err) + + // assert.True(t, VerifyProof(*circuit, setup, proof, false)) + b35 := big.NewInt(int64(35)) + publicSignals := []*big.Int{b35} + assert.True(t, VerifyProof(*circuit, setup, proof, publicSignals, true)) +} +*/ + func TestZkFromFlatCircuitCode(t *testing.T) { // compile circuit and get the R1CS @@ -30,11 +101,13 @@ func TestZkFromFlatCircuitCode(t *testing.T) { circuit, err := parser.Parse() assert.Nil(t, err) fmt.Println("\ncircuit data:", circuit) + circuitJson, _ := json.Marshal(circuit) + fmt.Println("circuit:", string(circuitJson)) b3 := big.NewInt(int64(3)) - inputs := []*big.Int{b3} + privateInputs := []*big.Int{b3} // wittness - w, err := circuit.CalculateWitness(inputs) + w, err := circuit.CalculateWitness(privateInputs) assert.Nil(t, err) fmt.Println("\nwitness", w) @@ -49,11 +122,16 @@ func TestZkFromFlatCircuitCode(t *testing.T) { // R1CS to QAP alphas, betas, gammas, zx := Utils.PF.R1CSToQAP(a, b, c) fmt.Println("qap") - fmt.Println(alphas) - fmt.Println(betas) - fmt.Println(gammas) + fmt.Println("alphas", alphas) + fmt.Println("betas", betas) + fmt.Println("gammas", gammas) + fmt.Println("zx", zx) ax, bx, cx, px := Utils.PF.CombinePolynomials(w, alphas, betas, gammas) + fmt.Println("ax", ax) + // fmt.Println("bx", bx) + // fmt.Println("cx", cx) + // fmt.Println("px", px) hx := Utils.PF.DivisorPolynomial(px, zx) @@ -72,21 +150,29 @@ func TestZkFromFlatCircuitCode(t *testing.T) { // calculate trusted setup setup, err := GenerateTrustedSetup(len(w), *circuit, alphas, betas, gammas, zx) + // setup, err := GenerateTrustedSetup(len(w), *circuit, ax, bx, cx, zx) assert.Nil(t, err) fmt.Println("\nt:", setup.Toxic.T) // piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t) proof, err := GenerateProofs(*circuit, setup, hx, w) assert.Nil(t, err) + fmt.Println("IC", setup.Vk.IC) - fmt.Println("\n proofs:") - fmt.Println(proof) - fmt.Println("public signals:", proof.PublicSignals) + // fmt.Println("\n proofs:") + // fmt.Println(proof) + + // fmt.Println("public signals:", proof.PublicSignals) + fmt.Println("\nwitness", w) + b35 := big.NewInt(int64(35)) + publicSignals := []*big.Int{b35} + fmt.Println("public signals:", publicSignals) before := time.Now() - assert.True(t, VerifyProof(*circuit, setup, proof, true)) + assert.True(t, VerifyProof(*circuit, setup, proof, publicSignals, true)) fmt.Println("verify proof time elapsed:", time.Since(before)) } +/* func TestZkFromHardcodedR1CS(t *testing.T) { b0 := big.NewInt(int64(0)) b1 := big.NewInt(int64(1)) @@ -148,7 +234,9 @@ func TestZkFromHardcodedR1CS(t *testing.T) { proof, err := GenerateProofs(circuit, setup, hx, w) assert.Nil(t, err) - assert.True(t, VerifyProof(circuit, setup, proof, true)) + // assert.True(t, VerifyProof(circuit, setup, proof, true)) + publicSignals := []*big.Int{b35} + assert.True(t, VerifyProof(circuit, setup, proof, publicSignals, true)) } func TestZkMultiplication(t *testing.T) { @@ -202,5 +290,9 @@ func TestZkMultiplication(t *testing.T) { proof, err := GenerateProofs(*circuit, setup, hx, w) assert.Nil(t, err) - assert.True(t, VerifyProof(*circuit, setup, proof, false)) + // assert.True(t, VerifyProof(*circuit, setup, proof, false)) + b35 := big.NewInt(int64(35)) + publicSignals := []*big.Int{b35} + assert.True(t, VerifyProof(*circuit, setup, proof, publicSignals, true)) } +*/