refactor in sub packages

This commit is contained in:
arnaucube
2020-04-17 20:31:53 +02:00
parent 96b4b5b308
commit e1147db72f
14 changed files with 238 additions and 177 deletions

102
prover/arithmetic.go Normal file
View File

@@ -0,0 +1,102 @@
package prover
import (
"bytes"
"math/big"
)
func arrayOfZeroes(n int) []*big.Int {
var r []*big.Int
for i := 0; i < n; i++ {
r = append(r, new(big.Int).SetInt64(0))
}
return r
}
func fAdd(a, b *big.Int) *big.Int {
ab := new(big.Int).Add(a, b)
return new(big.Int).Mod(ab, R)
}
func fSub(a, b *big.Int) *big.Int {
ab := new(big.Int).Sub(a, b)
return new(big.Int).Mod(ab, R)
}
func fMul(a, b *big.Int) *big.Int {
ab := new(big.Int).Mul(a, b)
return new(big.Int).Mod(ab, R)
}
func fDiv(a, b *big.Int) *big.Int {
ab := new(big.Int).Mul(a, new(big.Int).ModInverse(b, R))
return new(big.Int).Mod(ab, R)
}
func fNeg(a *big.Int) *big.Int {
return new(big.Int).Mod(new(big.Int).Neg(a), R)
}
func fInv(a *big.Int) *big.Int {
return new(big.Int).ModInverse(a, R)
}
func fExp(base *big.Int, e *big.Int) *big.Int {
res := big.NewInt(1)
rem := new(big.Int).Set(e)
exp := base
for !bytes.Equal(rem.Bytes(), big.NewInt(int64(0)).Bytes()) {
// if BigIsOdd(rem) {
if rem.Bit(0) == 1 { // .Bit(0) returns 1 when is odd
res = fMul(res, exp)
}
exp = fMul(exp, exp)
rem = new(big.Int).Rsh(rem, 1)
}
return res
}
func max(a, b int) int {
if a > b {
return a
}
return b
}
func polynomialSub(a, b []*big.Int) []*big.Int {
r := arrayOfZeroes(max(len(a), len(b)))
for i := 0; i < len(a); i++ {
r[i] = fAdd(r[i], a[i])
}
for i := 0; i < len(b); i++ {
r[i] = fSub(r[i], b[i])
}
return r
}
func polynomialMul(a, b []*big.Int) []*big.Int {
r := arrayOfZeroes(len(a) + len(b) - 1)
for i := 0; i < len(a); i++ {
for j := 0; j < len(b); j++ {
r[i+j] = fAdd(r[i+j], fMul(a[i], b[j]))
}
}
return r
}
func polynomialDiv(a, b []*big.Int) ([]*big.Int, []*big.Int) {
// https://en.wikipedia.org/wiki/Division_algorithm
r := arrayOfZeroes(len(a) - len(b) + 1)
rem := a
for len(rem) >= len(b) {
l := fDiv(rem[len(rem)-1], b[len(b)-1])
pos := len(rem) - len(b)
r[pos] = l
aux := arrayOfZeroes(pos)
aux1 := append(aux, l)
aux2 := polynomialSub(rem, polynomialMul(b, aux1))
rem = aux2[:len(aux2)-1]
}
return r, rem
}

101
prover/ifft.go Normal file
View File

@@ -0,0 +1,101 @@
package prover
import (
"math"
"math/big"
)
type rootsT struct {
roots [][]*big.Int
w []*big.Int
}
func newRootsT() rootsT {
var roots rootsT
rem := new(big.Int).Sub(R, big.NewInt(1))
s := 0
for rem.Bit(0) == 0 { // rem.Bit==0 when even
s++
rem = new(big.Int).Rsh(rem, 1)
}
roots.w = make([]*big.Int, s+1)
roots.w[s] = fExp(big.NewInt(5), rem)
n := s - 1
for n >= 0 {
roots.w[n] = fMul(roots.w[n+1], roots.w[n+1])
n--
}
roots.roots = make([][]*big.Int, 50) // TODO WIP
roots.setRoots(15)
return roots
}
func (roots rootsT) setRoots(n int) {
for i := n; i >= 0 && nil == roots.roots[i]; i-- { // TODO tmp i<=len(r)
r := big.NewInt(1)
nroots := 1 << i
var rootsi []*big.Int
for j := 0; j < nroots; j++ {
rootsi = append(rootsi, r)
r = fMul(r, roots.w[i])
}
roots.roots[i] = rootsi
}
}
func fft(roots rootsT, pall []*big.Int, bits, offset, step int) []*big.Int {
n := 1 << bits
if n == 1 {
return []*big.Int{pall[offset]}
} else if n == 2 {
return []*big.Int{
fAdd(pall[offset], pall[offset+step]), // TODO tmp
fSub(pall[offset], pall[offset+step]),
}
}
ndiv2 := n >> 1
p1 := fft(roots, pall, bits-1, offset, step*2)
p2 := fft(roots, pall, bits-1, offset+step, step*2)
// var out []*big.Int
out := make([]*big.Int, n)
for i := 0; i < ndiv2; i++ {
// fmt.Println(i, len(roots.roots))
out[i] = fAdd(p1[i], fMul(roots.roots[bits][i], p2[i]))
out[i+ndiv2] = fSub(p1[i], fMul(roots.roots[bits][i], p2[i]))
}
return out
}
func ifft(p []*big.Int) []*big.Int {
if len(p) <= 1 {
return p
}
bits := math.Log2(float64(len(p)-1)) + 1
roots := newRootsT()
roots.setRoots(int(bits))
m := 1 << int(bits)
ep := extend(p, m)
res := fft(roots, ep, int(bits), 0, 1)
twoinvm := fInv(fMul(big.NewInt(1), big.NewInt(int64(m))))
var resn []*big.Int
for i := 0; i < m; i++ {
resn = append(resn, fMul(res[(m-i)%m], twoinvm))
}
return resn
}
func extend(p []*big.Int, e int) []*big.Int {
if e == len(p) {
return p
}
z := arrayOfZeroes(e - len(p))
return append(p, z...)
}

135
prover/prover.go Normal file
View File

@@ -0,0 +1,135 @@
package prover
import (
"crypto/rand"
"math/big"
bn256 "github.com/ethereum/go-ethereum/crypto/bn256/cloudflare"
"github.com/iden3/go-circom-prover-verifier/types"
)
// Proof is the data structure of the Groth16 zkSNARK proof
type Proof struct {
A *bn256.G1
B *bn256.G2
C *bn256.G1
}
// Pk holds the data structure of the ProvingKey
type Pk struct {
A []*bn256.G1
B2 []*bn256.G2
B1 []*bn256.G1
C []*bn256.G1
NVars int
NPublic int
VkAlpha1 *bn256.G1
VkDelta1 *bn256.G1
VkBeta1 *bn256.G1
VkBeta2 *bn256.G2
VkDelta2 *bn256.G2
HExps []*bn256.G1
DomainSize int
PolsA []map[int]*big.Int
PolsB []map[int]*big.Int
PolsC []map[int]*big.Int
}
// Witness contains the witness
type Witness []*big.Int
// R is the mod of the finite field
var R, _ = new(big.Int).SetString("21888242871839275222246405745257275088548364400416034343698204186575808495617", 10)
func randBigInt() (*big.Int, error) {
maxbits := R.BitLen()
b := make([]byte, (maxbits/8)-1)
_, err := rand.Read(b)
if err != nil {
return nil, err
}
r := new(big.Int).SetBytes(b)
rq := new(big.Int).Mod(r, R)
return rq, nil
}
// GenerateProof generates the Groth16 zkSNARK proof
func GenerateProof(pk *types.Pk, w types.Witness) (*types.Proof, []*big.Int, error) {
var proof types.Proof
r, err := randBigInt()
if err != nil {
return nil, nil, err
}
s, err := randBigInt()
if err != nil {
return nil, nil, err
}
proof.A = new(bn256.G1).ScalarBaseMult(big.NewInt(0))
proof.B = new(bn256.G2).ScalarBaseMult(big.NewInt(0))
proof.C = new(bn256.G1).ScalarBaseMult(big.NewInt(0))
proofBG1 := new(bn256.G1).ScalarBaseMult(big.NewInt(0))
for i := 0; i < pk.NVars; i++ {
proof.A = new(bn256.G1).Add(proof.A, new(bn256.G1).ScalarMult(pk.A[i], w[i]))
proof.B = new(bn256.G2).Add(proof.B, new(bn256.G2).ScalarMult(pk.B2[i], w[i]))
proofBG1 = new(bn256.G1).Add(proofBG1, new(bn256.G1).ScalarMult(pk.B1[i], w[i]))
}
for i := pk.NPublic + 1; i < pk.NVars; i++ {
proof.C = new(bn256.G1).Add(proof.C, new(bn256.G1).ScalarMult(pk.C[i], w[i]))
}
proof.A = new(bn256.G1).Add(proof.A, pk.VkAlpha1)
proof.A = new(bn256.G1).Add(proof.A, new(bn256.G1).ScalarMult(pk.VkDelta1, r))
proof.B = new(bn256.G2).Add(proof.B, pk.VkBeta2)
proof.B = new(bn256.G2).Add(proof.B, new(bn256.G2).ScalarMult(pk.VkDelta2, s))
proofBG1 = new(bn256.G1).Add(proofBG1, pk.VkBeta1)
proofBG1 = new(bn256.G1).Add(proofBG1, new(bn256.G1).ScalarMult(pk.VkDelta1, s))
h := calculateH(pk, w)
for i := 0; i < len(h); i++ {
proof.C = new(bn256.G1).Add(proof.C, new(bn256.G1).ScalarMult(pk.HExps[i], h[i]))
}
proof.C = new(bn256.G1).Add(proof.C, new(bn256.G1).ScalarMult(proof.A, s))
proof.C = new(bn256.G1).Add(proof.C, new(bn256.G1).ScalarMult(proofBG1, r))
rsneg := new(big.Int).Mod(new(big.Int).Neg(new(big.Int).Mul(r, s)), R) // fAdd & fMul
proof.C = new(bn256.G1).Add(proof.C, new(bn256.G1).ScalarMult(pk.VkDelta1, rsneg))
pubSignals := w[1 : pk.NPublic+1]
return &proof, pubSignals, nil
}
func calculateH(pk *types.Pk, w types.Witness) []*big.Int {
m := pk.DomainSize
polAT := arrayOfZeroes(m)
polBT := arrayOfZeroes(m)
polCT := arrayOfZeroes(m)
for i := 0; i < pk.NVars; i++ {
for j := range pk.PolsA[i] {
polAT[j] = fAdd(polAT[j], fMul(w[i], pk.PolsA[i][j]))
}
for j := range pk.PolsB[i] {
polBT[j] = fAdd(polBT[j], fMul(w[i], pk.PolsB[i][j]))
}
for j := range pk.PolsC[i] {
polCT[j] = fAdd(polCT[j], fMul(w[i], pk.PolsC[i][j]))
}
}
polAS := ifft(polAT)
polBS := ifft(polBT)
polABS := polynomialMul(polAS, polBS)
polCS := ifft(polCT)
polABCS := polynomialSub(polABS, polCS)
hS := polABCS[m:]
return hS
}

106
prover/prover_test.go Normal file
View File

@@ -0,0 +1,106 @@
package prover
import (
"encoding/json"
"io/ioutil"
"math/big"
"testing"
"github.com/iden3/go-circom-prover-verifier/parsers"
"github.com/iden3/go-circom-prover-verifier/types"
"github.com/iden3/go-circom-prover-verifier/verifier"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSmallCircuitGenerateProf(t *testing.T) {
provingKeyJson, err := ioutil.ReadFile("../testdata/small/proving_key.json")
require.Nil(t, err)
pk, err := parsers.ParsePk(provingKeyJson)
require.Nil(t, err)
witnessJson, err := ioutil.ReadFile("../testdata/small/witness.json")
require.Nil(t, err)
w, err := parsers.ParseWitness(witnessJson)
require.Nil(t, err)
assert.Equal(t, types.Witness{big.NewInt(1), big.NewInt(33), big.NewInt(3), big.NewInt(11)}, w)
proof, pubSignals, err := GenerateProof(pk, w)
assert.Nil(t, err)
proofStr, err := parsers.ProofToJson(proof)
assert.Nil(t, err)
err = ioutil.WriteFile("../testdata/small/proof.json", proofStr, 0644)
assert.Nil(t, err)
publicStr, err := json.Marshal(parsers.ArrayBigIntToString(pubSignals))
assert.Nil(t, err)
err = ioutil.WriteFile("../testdata/small/public.json", publicStr, 0644)
assert.Nil(t, err)
// verify the proof
vkJson, err := ioutil.ReadFile("../testdata/small/verification_key.json")
require.Nil(t, err)
vk, err := parsers.ParseVk(vkJson)
require.Nil(t, err)
v := verifier.Verify(vk, proof, pubSignals)
assert.True(t, v)
// to verify the proof with snarkjs:
// snarkjs verify --vk testdata/small/verification_key.json -p testdata/small/proof.json --pub testdata/small/public.json
}
func TestBigCircuitGenerateProf(t *testing.T) {
provingKeyJson, err := ioutil.ReadFile("../testdata/big/proving_key.json")
require.Nil(t, err)
pk, err := parsers.ParsePk(provingKeyJson)
require.Nil(t, err)
witnessJson, err := ioutil.ReadFile("../testdata/big/witness.json")
require.Nil(t, err)
w, err := parsers.ParseWitness(witnessJson)
require.Nil(t, err)
proof, pubSignals, err := GenerateProof(pk, w)
assert.Nil(t, err)
proofStr, err := parsers.ProofToJson(proof)
assert.Nil(t, err)
err = ioutil.WriteFile("../testdata/big/proof.json", proofStr, 0644)
assert.Nil(t, err)
publicStr, err := json.Marshal(parsers.ArrayBigIntToString(pubSignals))
assert.Nil(t, err)
err = ioutil.WriteFile("../testdata/big/public.json", publicStr, 0644)
assert.Nil(t, err)
// verify the proof
vkJson, err := ioutil.ReadFile("../testdata/big/verification_key.json")
require.Nil(t, err)
vk, err := parsers.ParseVk(vkJson)
require.Nil(t, err)
v := verifier.Verify(vk, proof, pubSignals)
assert.True(t, v)
// to verify the proof with snarkjs:
// snarkjs verify --vk testdata/big/verification_key.json -p testdata/big/proof.json --pub testdata/big/public.json
}
func BenchmarkGenerateProof(b *testing.B) {
provingKeyJson, err := ioutil.ReadFile("../testdata/big/proving_key.json")
require.Nil(b, err)
pk, err := parsers.ParsePk(provingKeyJson)
require.Nil(b, err)
witnessJson, err := ioutil.ReadFile("../testdata/big/witness.json")
require.Nil(b, err)
w, err := parsers.ParseWitness(witnessJson)
require.Nil(b, err)
for i := 0; i < b.N; i++ {
GenerateProof(pk, w)
}
}