Browse Source

Add paralelization of polynomials in GenerateProof

- before:
BenchmarkGenerateProof-4               1        1553842743 ns/op
For a circuit of 9094 constraints takes 7.761949512s seconds to generate the proof.

- now:
BenchmarkGenerateProof-4               1        1331576862 ns/op
For a circuit of 9094 constraints takes 5.745279126s to generate the proof.

For bigger circuits (more constraints) the difference will be bigger.

Executed on a Intel(R) Core(TM) i5-7200U CPU @ 2.50GHz, with 16GB of RAM
ed255-patch-1
arnaucube 4 years ago
parent
commit
3691785054
3 changed files with 58 additions and 22 deletions
  1. +1
    -0
      .gitignore
  2. +49
    -21
      prover/prover.go
  3. +8
    -1
      prover/prover_test.go

+ 1
- 0
.gitignore

@ -3,5 +3,6 @@ testdata/*/*.wasm
testdata/*/*.cpp testdata/*/*.cpp
testdata/*/*.sym testdata/*/*.sym
testdata/*/*.r1cs testdata/*/*.r1cs
testdata/*/*.sol
!testdata/*/input.json !testdata/*/input.json
cli/*.json cli/*.json

+ 49
- 21
prover/prover.go

@ -4,6 +4,7 @@ import (
"crypto/rand" "crypto/rand"
"math" "math"
"math/big" "math/big"
"sync"
bn256 "github.com/ethereum/go-ethereum/crypto/bn256/cloudflare" bn256 "github.com/ethereum/go-ethereum/crypto/bn256/cloudflare"
"github.com/iden3/go-circom-prover-verifier/types" "github.com/iden3/go-circom-prover-verifier/types"
@ -72,30 +73,57 @@ func GenerateProof(pk *types.Pk, w types.Witness) (*types.Proof, []*big.Int, err
proof.C = new(bn256.G1).ScalarBaseMult(big.NewInt(0)) proof.C = new(bn256.G1).ScalarBaseMult(big.NewInt(0))
proofBG1 := new(bn256.G1).ScalarBaseMult(big.NewInt(0)) proofBG1 := new(bn256.G1).ScalarBaseMult(big.NewInt(0))
for i := 0; i < pk.NVars; i++ {
proof.A = new(bn256.G1).Add(proof.A, new(bn256.G1).ScalarMult(pk.A[i], w[i]))
proof.B = new(bn256.G2).Add(proof.B, new(bn256.G2).ScalarMult(pk.B2[i], w[i]))
proofBG1 = new(bn256.G1).Add(proofBG1, new(bn256.G1).ScalarMult(pk.B1[i], w[i]))
}
for i := pk.NPublic + 1; i < pk.NVars; i++ {
proof.C = new(bn256.G1).Add(proof.C, new(bn256.G1).ScalarMult(pk.C[i], w[i]))
}
proof.A = new(bn256.G1).Add(proof.A, pk.VkAlpha1)
proof.A = new(bn256.G1).Add(proof.A, new(bn256.G1).ScalarMult(pk.VkDelta1, r))
proof.B = new(bn256.G2).Add(proof.B, pk.VkBeta2)
proof.B = new(bn256.G2).Add(proof.B, new(bn256.G2).ScalarMult(pk.VkDelta2, s))
proofBG1 = new(bn256.G1).Add(proofBG1, pk.VkBeta1)
proofBG1 = new(bn256.G1).Add(proofBG1, new(bn256.G1).ScalarMult(pk.VkDelta1, s))
var waitGroup sync.WaitGroup
waitGroup.Add(4)
go func(wg *sync.WaitGroup) {
for i := 0; i < pk.NVars; i++ {
proof.A = new(bn256.G1).Add(proof.A, new(bn256.G1).ScalarMult(pk.A[i], w[i]))
}
wg.Done()
}(&waitGroup)
go func(wg *sync.WaitGroup) {
for i := 0; i < pk.NVars; i++ {
proof.B = new(bn256.G2).Add(proof.B, new(bn256.G2).ScalarMult(pk.B2[i], w[i]))
}
wg.Done()
}(&waitGroup)
go func(wg *sync.WaitGroup) {
for i := 0; i < pk.NVars; i++ {
proofBG1 = new(bn256.G1).Add(proofBG1, new(bn256.G1).ScalarMult(pk.B1[i], w[i]))
}
wg.Done()
}(&waitGroup)
go func(wg *sync.WaitGroup) {
for i := pk.NPublic + 1; i < pk.NVars; i++ {
proof.C = new(bn256.G1).Add(proof.C, new(bn256.G1).ScalarMult(pk.C[i], w[i]))
}
wg.Done()
}(&waitGroup)
waitGroup.Wait()
h := calculateH(pk, w) h := calculateH(pk, w)
for i := 0; i < len(h); i++ {
proof.C = new(bn256.G1).Add(proof.C, new(bn256.G1).ScalarMult(pk.HExps[i], h[i]))
}
var waitGroup2 sync.WaitGroup
waitGroup2.Add(2)
go func(wg *sync.WaitGroup) {
proof.A = new(bn256.G1).Add(proof.A, pk.VkAlpha1)
proof.A = new(bn256.G1).Add(proof.A, new(bn256.G1).ScalarMult(pk.VkDelta1, r))
proof.B = new(bn256.G2).Add(proof.B, pk.VkBeta2)
proof.B = new(bn256.G2).Add(proof.B, new(bn256.G2).ScalarMult(pk.VkDelta2, s))
proofBG1 = new(bn256.G1).Add(proofBG1, pk.VkBeta1)
proofBG1 = new(bn256.G1).Add(proofBG1, new(bn256.G1).ScalarMult(pk.VkDelta1, s))
wg.Done()
}(&waitGroup2)
go func(wg *sync.WaitGroup) {
for i := 0; i < len(h); i++ {
proof.C = new(bn256.G1).Add(proof.C, new(bn256.G1).ScalarMult(pk.HExps[i], h[i]))
}
wg.Done()
}(&waitGroup2)
waitGroup2.Wait()
proof.C = new(bn256.G1).Add(proof.C, new(bn256.G1).ScalarMult(proof.A, s)) proof.C = new(bn256.G1).Add(proof.C, new(bn256.G1).ScalarMult(proof.A, s))
proof.C = new(bn256.G1).Add(proof.C, new(bn256.G1).ScalarMult(proofBG1, r)) proof.C = new(bn256.G1).Add(proof.C, new(bn256.G1).ScalarMult(proofBG1, r))
rsneg := new(big.Int).Mod(new(big.Int).Neg(new(big.Int).Mul(r, s)), types.R) // fAdd & fMul rsneg := new(big.Int).Mod(new(big.Int).Neg(new(big.Int).Mul(r, s)), types.R) // fAdd & fMul

+ 8
- 1
prover/prover_test.go

@ -6,6 +6,7 @@ import (
"io/ioutil" "io/ioutil"
"math/big" "math/big"
"testing" "testing"
"time"
"github.com/iden3/go-circom-prover-verifier/parsers" "github.com/iden3/go-circom-prover-verifier/parsers"
"github.com/iden3/go-circom-prover-verifier/types" "github.com/iden3/go-circom-prover-verifier/types"
@ -27,8 +28,10 @@ func TestSmallCircuitGenerateProof(t *testing.T) {
assert.Equal(t, types.Witness{big.NewInt(1), big.NewInt(33), big.NewInt(3), big.NewInt(11)}, w) assert.Equal(t, types.Witness{big.NewInt(1), big.NewInt(33), big.NewInt(3), big.NewInt(11)}, w)
beforeT := time.Now()
proof, pubSignals, err := GenerateProof(pk, w) proof, pubSignals, err := GenerateProof(pk, w)
assert.Nil(t, err) assert.Nil(t, err)
fmt.Println("proof generation time elapsed:", time.Since(beforeT))
proofStr, err := parsers.ProofToJson(proof) proofStr, err := parsers.ProofToJson(proof)
assert.Nil(t, err) assert.Nil(t, err)
@ -64,8 +67,10 @@ func TestBigCircuitGenerateProof(t *testing.T) {
w, err := parsers.ParseWitness(witnessJson) w, err := parsers.ParseWitness(witnessJson)
require.Nil(t, err) require.Nil(t, err)
beforeT := time.Now()
proof, pubSignals, err := GenerateProof(pk, w) proof, pubSignals, err := GenerateProof(pk, w)
assert.Nil(t, err) assert.Nil(t, err)
fmt.Println("proof generation time elapsed:", time.Since(beforeT))
proofStr, err := parsers.ProofToJson(proof) proofStr, err := parsers.ProofToJson(proof)
assert.Nil(t, err) assert.Nil(t, err)
@ -99,7 +104,7 @@ func TestIdStateCircuitGenerateProof(t *testing.T) {
// trustedsetup files (generated in // trustedsetup files (generated in
// https://github.com/iden3/go-zksnark-full-flow-example) // https://github.com/iden3/go-zksnark-full-flow-example)
if false { if false {
fmt.Println("TestIdStateCircuitGenerateProof activated")
fmt.Println("\nTestIdStateCircuitGenerateProof activated")
provingKeyJson, err := ioutil.ReadFile("../testdata/idstate-circuit/proving_key.json") provingKeyJson, err := ioutil.ReadFile("../testdata/idstate-circuit/proving_key.json")
require.Nil(t, err) require.Nil(t, err)
pk, err := parsers.ParsePk(provingKeyJson) pk, err := parsers.ParsePk(provingKeyJson)
@ -110,8 +115,10 @@ func TestIdStateCircuitGenerateProof(t *testing.T) {
w, err := parsers.ParseWitness(witnessJson) w, err := parsers.ParseWitness(witnessJson)
require.Nil(t, err) require.Nil(t, err)
beforeT := time.Now()
proof, pubSignals, err := GenerateProof(pk, w) proof, pubSignals, err := GenerateProof(pk, w)
assert.Nil(t, err) assert.Nil(t, err)
fmt.Println("proof generation time elapsed:", time.Since(beforeT))
proofStr, err := parsers.ProofToJson(proof) proofStr, err := parsers.ProofToJson(proof)
assert.Nil(t, err) assert.Nil(t, err)

Loading…
Cancel
Save