diff --git a/parsers/parsers.go b/parsers/parsers.go index 0dfaf97..d8294c2 100644 --- a/parsers/parsers.go +++ b/parsers/parsers.go @@ -3,6 +3,7 @@ package parsers import ( "bufio" "bytes" + "encoding/binary" "encoding/hex" "encoding/json" "fmt" @@ -599,3 +600,368 @@ func swapEndianness(b []byte) []byte { } return o } + +func readNBytes(r io.Reader, n int) ([]byte, error) { + b := make([]byte, n) + _, err := io.ReadFull(r, b) + if err != nil { + return b, err + } + return b, nil +} + +// ParsePkBin parses binary file representation of the ProvingKey into the ProvingKey struct +func ParsePkBin(f *os.File) (*types.Pk, error) { + o := 0 + var pk types.Pk + r := bufio.NewReader(f) + + b, err := readNBytes(r, 12) + if err != nil { + return nil, err + } + pk.NVars = int(binary.LittleEndian.Uint32(b[:4])) + pk.NPublic = int(binary.LittleEndian.Uint32(b[4:8])) + pk.DomainSize = int(binary.LittleEndian.Uint32(b[8:12])) + o += 12 + + b, err = readNBytes(r, 8) + if err != nil { + return nil, err + } + pPolsA := int(binary.LittleEndian.Uint32(b[:4])) + pPolsB := int(binary.LittleEndian.Uint32(b[4:8])) + o += 8 + + b, err = readNBytes(r, 20) + if err != nil { + return nil, err + } + pPointsA := int(binary.LittleEndian.Uint32(b[:4])) + pPointsB1 := int(binary.LittleEndian.Uint32(b[4:8])) + pPointsB2 := int(binary.LittleEndian.Uint32(b[8:12])) + pPointsC := int(binary.LittleEndian.Uint32(b[12:16])) + pPointsHExps := int(binary.LittleEndian.Uint32(b[16:20])) + o += 20 + + b, err = readNBytes(r, 64) + if err != nil { + return nil, err + } + pk.VkAlpha1 = new(bn256.G1) + _, err = pk.VkAlpha1.Unmarshal(fromMont1Q(b)) + if err != nil { + return nil, err + } + + b, err = readNBytes(r, 64) + if err != nil { + return nil, err + } + pk.VkBeta1 = new(bn256.G1) + _, err = pk.VkBeta1.Unmarshal(fromMont1Q(b)) + if err != nil { + return nil, err + } + + b, err = readNBytes(r, 64) + if err != nil { + return nil, err + } + pk.VkDelta1 = new(bn256.G1) + _, err = pk.VkDelta1.Unmarshal(fromMont1Q(b)) + if err != nil { + return nil, err + } + b, err = readNBytes(r, 128) + if err != nil { + return nil, err + } + pk.VkBeta2 = new(bn256.G2) + _, err = pk.VkBeta2.Unmarshal(fromMont2Q(b)) + if err != nil { + return nil, err + } + b, err = readNBytes(r, 128) + if err != nil { + return nil, err + } + pk.VkDelta2 = new(bn256.G2) + _, err = pk.VkDelta2.Unmarshal(fromMont2Q(b)) + if err != nil { + return nil, err + } + o += 448 + if o != pPolsA { + return nil, fmt.Errorf("Unexpected offset, expected: %v, actual: %v", pPolsA, o) + } + + // PolsA + for i := 0; i < pk.NVars; i++ { + b, err = readNBytes(r, 4) + if err != nil { + return nil, err + } + keysLength := int(binary.LittleEndian.Uint32(b[:4])) + o += 4 + polsMap := make(map[int]*big.Int) + for j := 0; j < keysLength; j++ { + bK, err := readNBytes(r, 4) + if err != nil { + return nil, err + } + key := int(binary.LittleEndian.Uint32(bK[:4])) + o += 4 + + b, err := readNBytes(r, 32) + if err != nil { + return nil, err + } + polsMap[key] = new(big.Int).SetBytes(fromMont1R(b[:32])) + o += 32 + } + pk.PolsA = append(pk.PolsA, polsMap) + } + if o != pPolsB { + return nil, fmt.Errorf("Unexpected offset, expected: %v, actual: %v", pPolsB, o) + } + // PolsB + for i := 0; i < pk.NVars; i++ { + b, err = readNBytes(r, 4) + if err != nil { + return nil, err + } + keysLength := int(binary.LittleEndian.Uint32(b[:4])) + o += 4 + polsMap := make(map[int]*big.Int) + for j := 0; j < keysLength; j++ { + bK, err := readNBytes(r, 4) + if err != nil { + return nil, err + } + key := int(binary.LittleEndian.Uint32(bK[:4])) + o += 4 + + b, err := readNBytes(r, 32) + if err != nil { + return nil, err + } + polsMap[key] = new(big.Int).SetBytes(fromMont1R(b[:32])) + o += 32 + } + pk.PolsB = append(pk.PolsB, polsMap) + } + if o != pPointsA { + return nil, fmt.Errorf("Unexpected offset, expected: %v, actual: %v", pPointsA, o) + } + // A + for i := 0; i < pk.NVars; i++ { + b, err = readNBytes(r, 64) + if err != nil { + return nil, err + } + p1 := new(bn256.G1) + _, err = p1.Unmarshal(fromMont1Q(b)) + if err != nil { + return nil, err + } + pk.A = append(pk.A, p1) + o += 64 + } + if o != pPointsB1 { + return nil, fmt.Errorf("Unexpected offset, expected: %v, actual: %v", pPointsB1, o) + } + // B1 + for i := 0; i < pk.NVars; i++ { + b, err = readNBytes(r, 64) + if err != nil { + return nil, err + } + p1 := new(bn256.G1) + _, err = p1.Unmarshal(fromMont1Q(b)) + if err != nil { + return nil, err + } + pk.B1 = append(pk.B1, p1) + o += 64 + } + if o != pPointsB2 { + return nil, fmt.Errorf("Unexpected offset, expected: %v, actual: %v", pPointsB2, o) + } + // B2 + for i := 0; i < pk.NVars; i++ { + b, err = readNBytes(r, 128) + if err != nil { + return nil, err + } + p2 := new(bn256.G2) + _, err = p2.Unmarshal(fromMont2Q(b)) + if err != nil { + return nil, err + } + pk.B2 = append(pk.B2, p2) + o += 128 + } + if o != pPointsC { + return nil, fmt.Errorf("Unexpected offset, expected: %v, actual: %v", pPointsC, o) + } + // C + zb := make([]byte, 64) + z := new(bn256.G1) + _, err = z.Unmarshal(zb) + if err != nil { + return nil, err + } + pk.C = append(pk.C, z) // circom behaviour (3x null==["0", "0", "0"]) + pk.C = append(pk.C, z) + pk.C = append(pk.C, z) + for i := pk.NPublic + 1; i < pk.NVars; i++ { + b, err = readNBytes(r, 64) + if err != nil { + return nil, err + } + p1 := new(bn256.G1) + _, err = p1.Unmarshal(fromMont1Q(b)) + if err != nil { + return nil, err + } + pk.C = append(pk.C, p1) + o += 64 + } + if o != pPointsHExps { + return nil, fmt.Errorf("Unexpected offset, expected: %v, actual: %v", pPointsHExps, o) + } + for i := 0; i < pk.DomainSize; i++ { + b, err = readNBytes(r, 64) + if err != nil { + return nil, err + } + p1 := new(bn256.G1) + _, err = p1.Unmarshal(fromMont1Q(b)) + if err != nil { + return nil, err + } + pk.HExps = append(pk.HExps, p1) + } + return &pk, nil +} + +func fromMont1Q(m []byte) []byte { + a := new(big.Int).SetBytes(swapEndianness(m[:32])) + b := new(big.Int).SetBytes(swapEndianness(m[32:64])) + + x := coordFromMont(a, types.Q) + y := coordFromMont(b, types.Q) + if bytes.Equal(x.Bytes(), big.NewInt(1).Bytes()) { + x = big.NewInt(0) + } + if bytes.Equal(y.Bytes(), big.NewInt(1).Bytes()) { + y = big.NewInt(0) + } + + xBytes := x.Bytes() + yBytes := y.Bytes() + if len(xBytes) != 32 { + xBytes = addZPadding(xBytes) + } + if len(yBytes) != 32 { + yBytes = addZPadding(yBytes) + } + + var p []byte + p = append(p, xBytes...) + p = append(p, yBytes...) + + return p +} + +func fromMont2Q(m []byte) []byte { + a := new(big.Int).SetBytes(swapEndianness(m[:32])) + b := new(big.Int).SetBytes(swapEndianness(m[32:64])) + c := new(big.Int).SetBytes(swapEndianness(m[64:96])) + d := new(big.Int).SetBytes(swapEndianness(m[96:128])) + + x := coordFromMont(a, types.Q) + y := coordFromMont(b, types.Q) + z := coordFromMont(c, types.Q) + t := coordFromMont(d, types.Q) + + if bytes.Equal(x.Bytes(), big.NewInt(1).Bytes()) { + x = big.NewInt(0) + } + if bytes.Equal(y.Bytes(), big.NewInt(1).Bytes()) { + y = big.NewInt(0) + } + if bytes.Equal(z.Bytes(), big.NewInt(1).Bytes()) { + z = big.NewInt(0) + } + if bytes.Equal(t.Bytes(), big.NewInt(1).Bytes()) { + t = big.NewInt(0) + } + + xBytes := x.Bytes() + yBytes := y.Bytes() + zBytes := z.Bytes() + tBytes := t.Bytes() + if len(xBytes) != 32 { + xBytes = addZPadding(xBytes) + } + if len(yBytes) != 32 { + yBytes = addZPadding(yBytes) + } + if len(zBytes) != 32 { + zBytes = addZPadding(zBytes) + } + if len(tBytes) != 32 { + tBytes = addZPadding(tBytes) + } + + var p []byte + p = append(p, yBytes...) // swap + p = append(p, xBytes...) + p = append(p, tBytes...) + p = append(p, zBytes...) + + return p +} + +func fromMont1R(m []byte) []byte { + a := new(big.Int).SetBytes(swapEndianness(m[:32])) + + x := coordFromMont(a, types.R) + + return x.Bytes() +} + +func fromMont2R(m []byte) []byte { + a := new(big.Int).SetBytes(swapEndianness(m[:32])) + b := new(big.Int).SetBytes(swapEndianness(m[32:64])) + c := new(big.Int).SetBytes(swapEndianness(m[64:96])) + d := new(big.Int).SetBytes(swapEndianness(m[96:128])) + + x := coordFromMont(a, types.R) + y := coordFromMont(b, types.R) + z := coordFromMont(c, types.R) + t := coordFromMont(d, types.R) + + var p []byte + p = append(p, y.Bytes()...) // swap + p = append(p, x.Bytes()...) + p = append(p, t.Bytes()...) + p = append(p, z.Bytes()...) + + return p +} + +func coordFromMont(u, q *big.Int) *big.Int { + return new(big.Int).Mod( + new(big.Int).Mul( + u, + new(big.Int).ModInverse( + new(big.Int).Lsh(big.NewInt(1), 256), + q, + ), + ), + q, + ) +} diff --git a/parsers/parsers_test.go b/parsers/parsers_test.go index 82b6dea..d1d4db8 100644 --- a/parsers/parsers_test.go +++ b/parsers/parsers_test.go @@ -196,3 +196,63 @@ func TestProofSmartContractFormat(t *testing.T) { assert.Equal(t, pSC, pSC2) } + +func testCircuitParsePkBin(t *testing.T, circuit string) { + pkBinFile, err := os.Open("../testdata/" + circuit + "/proving_key.bin") + require.Nil(t, err) + defer pkBinFile.Close() + pk, err := ParsePkBin(pkBinFile) + require.Nil(t, err) + + pkJson, err := ioutil.ReadFile("../testdata/" + circuit + "/proving_key.json") + require.Nil(t, err) + pkJ, err := ParsePk(pkJson) + require.Nil(t, err) + + assert.Equal(t, pkJ.NVars, pk.NVars) + assert.Equal(t, pkJ.NPublic, pk.NPublic) + assert.Equal(t, pkJ.DomainSize, pk.DomainSize) + assert.Equal(t, pkJ.VkAlpha1, pk.VkAlpha1) + assert.Equal(t, pkJ.VkBeta1, pk.VkBeta1) + assert.Equal(t, pkJ.VkDelta1, pk.VkDelta1) + assert.Equal(t, pkJ.VkDelta2, pk.VkDelta2) + assert.Equal(t, pkJ.PolsA, pk.PolsA) + assert.Equal(t, pkJ.PolsB, pk.PolsB) + assert.Equal(t, pkJ.A, pk.A) + assert.Equal(t, pkJ.B1, pk.B1) + assert.Equal(t, pkJ.B2, pk.B2) + assert.Equal(t, pkJ.C, pk.C) + assert.Equal(t, pkJ.HExps[:pkJ.DomainSize], pk.HExps[:pk.DomainSize]) // circom behaviour +} + +func TestParsePkBin(t *testing.T) { + testCircuitParsePkBin(t, "circuit1k") + testCircuitParsePkBin(t, "circuit5k") +} + +func benchmarkParsePk(b *testing.B, circuit string) { + pkJson, err := ioutil.ReadFile("../testdata/" + circuit + "/proving_key.json") + require.Nil(b, err) + + pkBinFile, err := os.Open("../testdata/" + circuit + "/proving_key.bin") + require.Nil(b, err) + defer pkBinFile.Close() + + b.Run("Parse Pk bin "+circuit, func(b *testing.B) { + for i := 0; i < b.N; i++ { + ParsePk(pkJson) + } + }) + b.Run("Parse Pk json "+circuit, func(b *testing.B) { + for i := 0; i < b.N; i++ { + ParsePkBin(pkBinFile) + } + }) +} + +func BenchmarkParsePk(b *testing.B) { + benchmarkParsePk(b, "circuit1k") + benchmarkParsePk(b, "circuit5k") + // benchmarkParsePk(b, "circuit10k") + // benchmarkParsePk(b, "circuit20k") +} diff --git a/prover/prover_test.go b/prover/prover_test.go index ea7350a..e7fed45 100644 --- a/prover/prover_test.go +++ b/prover/prover_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "io/ioutil" + "os" "testing" "time" @@ -16,25 +17,38 @@ import ( func TestCircuitsGenerateProof(t *testing.T) { testCircuitGenerateProof(t, "circuit1k") // 1000 constraints testCircuitGenerateProof(t, "circuit5k") // 5000 constraints - //testCircuitGenerateProof(t, "circuit10k") // 10000 constraints - //testCircuitGenerateProof(t, "circuit20k") // 20000 constraints + // testCircuitGenerateProof(t, "circuit10k") // 10000 constraints + // testCircuitGenerateProof(t, "circuit20k") // 20000 constraints } func testCircuitGenerateProof(t *testing.T, circuit string) { - provingKeyJson, err := ioutil.ReadFile("../testdata/" + circuit + "/proving_key.json") + // using json provingKey file + // provingKeyJson, err := ioutil.ReadFile("../testdata/" + circuit + "/proving_key.json") + // require.Nil(t, err) + // pk, err := parsers.ParsePk(provingKeyJson) + // require.Nil(t, err) + // witnessJson, err := ioutil.ReadFile("../testdata/" + circuit + "/witness.json") + // require.Nil(t, err) + // w, err := parsers.ParseWitness(witnessJson) + // require.Nil(t, err) + + // using bin provingKey file + pkBinFile, err := os.Open("../testdata/" + circuit + "/proving_key.bin") require.Nil(t, err) - pk, err := parsers.ParsePk(provingKeyJson) + defer pkBinFile.Close() + pk, err := parsers.ParsePkBin(pkBinFile) require.Nil(t, err) - witnessJson, err := ioutil.ReadFile("../testdata/" + circuit + "/witness.json") + witnessBinFile, err := os.Open("../testdata/" + circuit + "/witness.bin") require.Nil(t, err) - w, err := parsers.ParseWitness(witnessJson) + defer witnessBinFile.Close() + w, err := parsers.ParseWitnessBin(witnessBinFile) require.Nil(t, err) beforeT := time.Now() proof, pubSignals, err := GenerateProof(pk, w) assert.Nil(t, err) - fmt.Println("proof generation time elapsed:", time.Since(beforeT)) + fmt.Println("proof generation time for "+circuit+" elapsed:", time.Since(beforeT)) proofStr, err := parsers.ProofToJson(proof) assert.Nil(t, err) diff --git a/types/types.go b/types/types.go index 8277579..4e4e8f2 100644 --- a/types/types.go +++ b/types/types.go @@ -6,6 +6,8 @@ import ( bn256 "github.com/ethereum/go-ethereum/crypto/bn256/cloudflare" ) +var Q, _ = new(big.Int).SetString("21888242871839275222246405745257275088696311157297823662689037894645226208583", 10) + // R is the mod of the finite field var R, _ = new(big.Int).SetString("21888242871839275222246405745257275088548364400416034343698204186575808495617", 10)