From aefb298bb007a3e804bafa12d6cf6bb6d9140038 Mon Sep 17 00:00:00 2001 From: arnaucube Date: Fri, 28 Dec 2018 00:46:42 +0100 Subject: [PATCH] circuit CalculateWitness, added - & / in GenerateR1CS(), added doc --- README.md | 21 +++++- bn128/bn128.go | 44 ++++++----- bn128/bn128_test.go | 10 +-- circuitcompiler/circuit.go | 60 ++++++++++++++- circuitcompiler/circuit_test.go | 6 ++ circuitcompiler/lexer.go | 5 +- circuitcompiler/parser.go | 10 ++- r1csqap/r1csqap.go | 16 ++++ snark.go | 5 ++ snark_test.go | 129 +++++++++++++++++--------------- 10 files changed, 211 insertions(+), 95 deletions(-) diff --git a/README.md b/README.md index 52f2415..9ba5773 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,20 @@ Implementation from scratch in Go to understand the concepts. Do not use in prod Not finished, implementing this in my free time to understand it better, so I don't have much time. +Current implementation status: +- [x] Finite Fields (1, 2, 6, 12) operations +- [x] G1 and G2 operations +- [x] BN128 Pairing +- [x] circuit code compiler + - [ ] code to flat code + - [x] flat code compiler +- [x] circuit to R1CS +- [x] polynomial operations +- [x] R1CS to QAP +- [x] generate trusted setup +- [x] generate proofs +- [x] verify proofs with BN128 pairing + ### Usage - [![GoDoc](https://godoc.org/github.com/arnaucube/go-snark?status.svg)](https://godoc.org/github.com/arnaucube/go-snark) zkSnark @@ -57,8 +71,11 @@ c == [[0 0 0 1 0 0] [0 0 0 0 1 0] [0 0 0 0 0 1] [0 0 1 0 0 0]] alphas, betas, gammas, zx := pf.R1CSToQAP(a, b, c) -// wittness = 1, 3, 35, 9, 27, 30 -w := []*big.Int{b1, b3, b35, b9, b27, b30} +// wittness +b3 := big.NewInt(int64(3)) +inputs := []*big.Int{b3} +w := circuit.CalculateWitness(inputs) +fmt.Println("\nwitness", w) ax, bx, cx, px := pf.CombinePolynomials(w, alphas, betas, gammas) diff --git a/bn128/bn128.go b/bn128/bn128.go index 8af445b..ebef4ee 100644 --- a/bn128/bn128.go +++ b/bn128/bn128.go @@ -7,6 +7,7 @@ import ( "github.com/arnaucube/go-snark/fields" ) +// Bn128 is the data structure of the BN128 type Bn128 struct { Q *big.Int R *big.Int @@ -33,6 +34,7 @@ type Bn128 struct { FinalExp *big.Int } +// NewBn128 returns the BN128 func NewBn128() (Bn128, error) { var b Bn128 q, ok := new(big.Int).SetString("21888242871839275222246405745257275088696311157297823662689037894645226208583", 10) @@ -105,6 +107,7 @@ func NewBn128() (Bn128, error) { return b, nil } +// NewFqR returns a new Finite Field over R func NewFqR() (fields.Fq, error) { r, ok := new(big.Int).SetString("21888242871839275222246405745257275088548364400416034343698204186575808495617", 10) if !ok { @@ -172,12 +175,13 @@ func (bn128 *Bn128) preparePairing() error { } +// Pairing calculates the BN128 Pairing of two given values func (bn128 Bn128) Pairing(p1 [3]*big.Int, p2 [3][2]*big.Int) [2][3][2]*big.Int { - pre1 := bn128.PreComputeG1(p1) - pre2 := bn128.PreComputeG2(p2) + pre1 := bn128.preComputeG1(p1) + pre2 := bn128.preComputeG2(p2) r1 := bn128.MillerLoop(pre1, pre2) - res := bn128.FinalExponentiation(r1) + res := bn128.finalExponentiation(r1) return res } @@ -186,7 +190,7 @@ type AteG1Precomp struct { Py *big.Int } -func (bn128 Bn128) PreComputeG1(p [3]*big.Int) AteG1Precomp { +func (bn128 Bn128) preComputeG1(p [3]*big.Int) AteG1Precomp { pCopy := bn128.G1.Affine(p) res := AteG1Precomp{ Px: pCopy[0], @@ -206,7 +210,7 @@ type AteG2Precomp struct { Coeffs []EllCoeffs } -func (bn128 Bn128) PreComputeG2(p [3][2]*big.Int) AteG2Precomp { +func (bn128 Bn128) preComputeG2(p [3][2]*big.Int) AteG2Precomp { qCopy := bn128.G2.Affine(p) res := AteG2Precomp{ qCopy[0], @@ -222,20 +226,20 @@ func (bn128 Bn128) PreComputeG2(p [3][2]*big.Int) AteG2Precomp { for i := bn128.LoopCount.BitLen() - 2; i >= 0; i-- { bit := bn128.LoopCount.Bit(i) - c, r = bn128.DoublingStep(r) + c, r = bn128.doublingStep(r) res.Coeffs = append(res.Coeffs, c) if bit == 1 { - c, r = bn128.MixedAdditionStep(qCopy, r) + c, r = bn128.mixedAdditionStep(qCopy, r) res.Coeffs = append(res.Coeffs, c) } } - q1 := bn128.G2.Affine(bn128.G2MulByQ(qCopy)) + q1 := bn128.G2.Affine(bn128.g2MulByQ(qCopy)) if !bn128.Fq2.Equal(q1[2], bn128.Fq2.One()) { // return res, errors.New("q1[2] != Fq2.One") panic(errors.New("q1[2] != Fq2.One()")) } - q2 := bn128.G2.Affine(bn128.G2MulByQ(q1)) + q2 := bn128.G2.Affine(bn128.g2MulByQ(q1)) if !bn128.Fq2.Equal(q2[2], bn128.Fq2.One()) { // return res, errors.New("q2[2] != Fq2.One") panic(errors.New("q2[2] != Fq2.One()")) @@ -246,16 +250,16 @@ func (bn128 Bn128) PreComputeG2(p [3][2]*big.Int) AteG2Precomp { } q2[1] = bn128.Fq2.Neg(q2[1]) - c, r = bn128.MixedAdditionStep(q1, r) + c, r = bn128.mixedAdditionStep(q1, r) res.Coeffs = append(res.Coeffs, c) - c, r = bn128.MixedAdditionStep(q2, r) + c, r = bn128.mixedAdditionStep(q2, r) res.Coeffs = append(res.Coeffs, c) return res } -func (bn128 Bn128) DoublingStep(current [3][2]*big.Int) (EllCoeffs, [3][2]*big.Int) { +func (bn128 Bn128) doublingStep(current [3][2]*big.Int) (EllCoeffs, [3][2]*big.Int) { x := current[0] y := current[1] z := current[2] @@ -286,7 +290,7 @@ func (bn128 Bn128) DoublingStep(current [3][2]*big.Int) (EllCoeffs, [3][2]*big.I return res, current } -func (bn128 Bn128) MixedAdditionStep(base, current [3][2]*big.Int) (EllCoeffs, [3][2]*big.Int) { +func (bn128 Bn128) mixedAdditionStep(base, current [3][2]*big.Int) (EllCoeffs, [3][2]*big.Int) { x1 := current[0] y1 := current[1] z1 := current[2] @@ -320,7 +324,7 @@ func (bn128 Bn128) MixedAdditionStep(base, current [3][2]*big.Int) (EllCoeffs, [ } return coef, current } -func (bn128 Bn128) G2MulByQ(p [3][2]*big.Int) [3][2]*big.Int { +func (bn128 Bn128) g2MulByQ(p [3][2]*big.Int) [3][2]*big.Int { fmx := [2]*big.Int{ p[0][0], bn128.Fq1.Mul(p[0][1], bn128.Fq1.Copy(bn128.FrobeniusCoeffsC11)), @@ -356,7 +360,7 @@ func (bn128 Bn128) MillerLoop(pre1 AteG1Precomp, pre2 AteG2Precomp) [2][3][2]*bi idx++ f = bn128.Fq12.Square(f) - f = bn128.MulBy024(f, + f = bn128.mulBy024(f, c.Ell0, bn128.Fq2.MulScalar(c.EllVW, pre1.Py), bn128.Fq2.MulScalar(c.EllVV, pre1.Px)) @@ -364,7 +368,7 @@ func (bn128 Bn128) MillerLoop(pre1 AteG1Precomp, pre2 AteG2Precomp) [2][3][2]*bi if bit == 1 { c = pre2.Coeffs[idx] idx++ - f = bn128.MulBy024( + f = bn128.mulBy024( f, c.Ell0, bn128.Fq2.MulScalar(c.EllVW, pre1.Py), @@ -377,7 +381,7 @@ func (bn128 Bn128) MillerLoop(pre1 AteG1Precomp, pre2 AteG2Precomp) [2][3][2]*bi c = pre2.Coeffs[idx] idx++ - f = bn128.MulBy024( + f = bn128.mulBy024( f, c.Ell0, bn128.Fq2.MulScalar(c.EllVW, pre1.Py), @@ -386,7 +390,7 @@ func (bn128 Bn128) MillerLoop(pre1 AteG1Precomp, pre2 AteG2Precomp) [2][3][2]*bi c = pre2.Coeffs[idx] idx++ - f = bn128.MulBy024( + f = bn128.mulBy024( f, c.Ell0, bn128.Fq2.MulScalar(c.EllVW, pre1.Py), @@ -395,7 +399,7 @@ func (bn128 Bn128) MillerLoop(pre1 AteG1Precomp, pre2 AteG2Precomp) [2][3][2]*bi return f } -func (bn128 Bn128) MulBy024(a [2][3][2]*big.Int, ell0, ellVW, ellVV [2]*big.Int) [2][3][2]*big.Int { +func (bn128 Bn128) mulBy024(a [2][3][2]*big.Int, ell0, ellVW, ellVV [2]*big.Int) [2][3][2]*big.Int { b := [2][3][2]*big.Int{ [3][2]*big.Int{ ell0, @@ -411,7 +415,7 @@ func (bn128 Bn128) MulBy024(a [2][3][2]*big.Int, ell0, ellVW, ellVV [2]*big.Int) return bn128.Fq12.Mul(a, b) } -func (bn128 Bn128) FinalExponentiation(r [2][3][2]*big.Int) [2][3][2]*big.Int { +func (bn128 Bn128) finalExponentiation(r [2][3][2]*big.Int) [2][3][2]*big.Int { res := bn128.Fq12.Exp(r, bn128.FinalExp) return res } diff --git a/bn128/bn128_test.go b/bn128/bn128_test.go index ecfa585..0abd7b2 100644 --- a/bn128/bn128_test.go +++ b/bn128/bn128_test.go @@ -21,11 +21,11 @@ func TestBN128(t *testing.T) { g1b := bn128.G1.MulScalar(bn128.G1.G, bn128.Fq1.Copy(big75)) g2b := bn128.G2.MulScalar(bn128.G2.G, bn128.Fq1.Copy(big40)) - pre1a := bn128.PreComputeG1(g1a) - pre2a := bn128.PreComputeG2(g2a) + pre1a := bn128.preComputeG1(g1a) + pre2a := bn128.preComputeG2(g2a) assert.Nil(t, err) - pre1b := bn128.PreComputeG1(g1b) - pre2b := bn128.PreComputeG2(g2b) + pre1b := bn128.preComputeG1(g1b) + pre2b := bn128.preComputeG2(g2b) assert.Nil(t, err) r1 := bn128.MillerLoop(pre1a, pre2a) @@ -33,7 +33,7 @@ func TestBN128(t *testing.T) { rbe := bn128.Fq12.Mul(r1, bn128.Fq12.Inverse(r2)) - res := bn128.FinalExponentiation(rbe) + res := bn128.finalExponentiation(rbe) a := bn128.Fq12.Affine(res) b := bn128.Fq12.Affine(bn128.Fq12.One()) diff --git a/circuitcompiler/circuit.go b/circuitcompiler/circuit.go index 9143551..f3f6c3b 100644 --- a/circuitcompiler/circuit.go +++ b/circuitcompiler/circuit.go @@ -8,6 +8,7 @@ import ( "github.com/arnaucube/go-snark/r1csqap" ) +// Circuit is the data structure of the compiled circuit type Circuit struct { NVars int NPublic int @@ -22,6 +23,8 @@ type Circuit struct { C [][]*big.Int } } + +// Constraint is the data structure of a flat code operation type Constraint struct { // v1 op v2 = out Op string @@ -61,7 +64,21 @@ func insertVar(arr []*big.Int, signals []string, v string, used map[string]bool) } return arr, used } +func insertVarNeg(arr []*big.Int, signals []string, v string, used map[string]bool) ([]*big.Int, map[string]bool) { + isVal, value := isValue(v) + valueBigInt := big.NewInt(int64(value)) + if isVal { + arr[0] = new(big.Int).Add(arr[0], valueBigInt) + } else { + if !used[v] { + panic(errors.New("using variable before it's set")) + } + arr[indexInArray(signals, v)] = new(big.Int).Add(arr[indexInArray(signals, v)], big.NewInt(int64(-1))) + } + return arr, used +} +// GenerateR1CS generates the R1CS polynomials from the Circuit func (circ *Circuit) GenerateR1CS() ([][]*big.Int, [][]*big.Int, [][]*big.Int) { // from flat code to R1CS @@ -71,7 +88,6 @@ func (circ *Circuit) GenerateR1CS() ([][]*big.Int, [][]*big.Int, [][]*big.Int) { used := make(map[string]bool) for _, constraint := range circ.Constraints { - aConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals)) bConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals)) cConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals)) @@ -86,7 +102,6 @@ func (circ *Circuit) GenerateR1CS() ([][]*big.Int, [][]*big.Int, [][]*big.Int) { aConstraint[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Add(aConstraint[indexInArray(circ.Signals, constraint.Out)], big.NewInt(int64(1))) aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.Out, used) bConstraint[0] = big.NewInt(int64(1)) - } continue @@ -95,10 +110,19 @@ func (circ *Circuit) GenerateR1CS() ([][]*big.Int, [][]*big.Int, [][]*big.Int) { aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V1, used) aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V2, used) bConstraint[0] = big.NewInt(int64(1)) + } else if constraint.Op == "-" { + cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1)) + aConstraint, used = insertVarNeg(aConstraint, circ.Signals, constraint.V1, used) + aConstraint, used = insertVarNeg(aConstraint, circ.Signals, constraint.V2, used) + bConstraint[0] = big.NewInt(int64(1)) } else if constraint.Op == "*" { cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1)) aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V1, used) bConstraint, used = insertVar(bConstraint, circ.Signals, constraint.V2, used) + } else if constraint.Op == "/" { + cConstraint, used = insertVar(cConstraint, circ.Signals, constraint.V1, used) + cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1)) + bConstraint, used = insertVar(bConstraint, circ.Signals, constraint.V2, used) } a = append(a, aConstraint) @@ -108,3 +132,35 @@ func (circ *Circuit) GenerateR1CS() ([][]*big.Int, [][]*big.Int, [][]*big.Int) { } return a, b, c } + +func grabVar(signals []string, w []*big.Int, vStr string) *big.Int { + isVal, v := isValue(vStr) + vBig := big.NewInt(int64(v)) + if isVal { + return vBig + } else { + return w[indexInArray(signals, vStr)] + } +} + +// CalculateWitness calculates the Witness of a Circuit based on the given inputs +func (circ *Circuit) CalculateWitness(inputs []*big.Int) []*big.Int { + w := r1csqap.ArrayOfBigZeros(len(circ.Signals)) + w[0] = big.NewInt(int64(1)) + for i, input := range inputs { + w[i+1] = input + } + for _, constraint := range circ.Constraints { + if constraint.Op == "in" { + } else if constraint.Op == "+" { + w[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Add(grabVar(circ.Signals, w, constraint.V1), grabVar(circ.Signals, w, constraint.V2)) + } else if constraint.Op == "-" { + w[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Sub(grabVar(circ.Signals, w, constraint.V1), grabVar(circ.Signals, w, constraint.V2)) + } else if constraint.Op == "*" { + w[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Mul(grabVar(circ.Signals, w, constraint.V1), grabVar(circ.Signals, w, constraint.V2)) + } else if constraint.Op == "/" { + w[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Div(grabVar(circ.Signals, w, constraint.V1), grabVar(circ.Signals, w, constraint.V2)) + } + } + return w +} diff --git a/circuitcompiler/circuit_test.go b/circuitcompiler/circuit_test.go index 9cb183a..c1912ca 100644 --- a/circuitcompiler/circuit_test.go +++ b/circuitcompiler/circuit_test.go @@ -74,4 +74,10 @@ func TestCircuitParser(t *testing.T) { fmt.Println(a) fmt.Println(b) fmt.Println(c) + + b3 := big.NewInt(int64(3)) + inputs := []*big.Int{b3} + // Calculate Witness + w := circuit.CalculateWitness(inputs) + fmt.Println("w", w) } diff --git a/circuitcompiler/lexer.go b/circuitcompiler/lexer.go index 75be8cd..92bc45c 100644 --- a/circuitcompiler/lexer.go +++ b/circuitcompiler/lexer.go @@ -42,10 +42,12 @@ func isDigit(ch rune) bool { return (ch >= '0' && ch <= '9') } +// Scanner holds the bufio.Reader type Scanner struct { r *bufio.Reader } +// NewScanner creates a new Scanner with the given io.Reader func NewScanner(r io.Reader) *Scanner { return &Scanner{r: bufio.NewReader(r)} } @@ -62,7 +64,8 @@ func (s *Scanner) unread() { _ = s.r.UnreadRune() } -func (s *Scanner) Scan() (tok Token, lit string) { +// Scan returns the Token and literal string of the current value +func (s *Scanner) scan() (tok Token, lit string) { ch := s.read() if isWhitespace(ch) { diff --git a/circuitcompiler/parser.go b/circuitcompiler/parser.go index cc94700..584d511 100644 --- a/circuitcompiler/parser.go +++ b/circuitcompiler/parser.go @@ -7,6 +7,7 @@ import ( "strings" ) +// Parser data structure holds the Scanner and the Parsing functions type Parser struct { s *Scanner buf struct { @@ -16,6 +17,7 @@ type Parser struct { } } +// NewParser creates a new parser from a io.Reader func NewParser(r io.Reader) *Parser { return &Parser{s: NewScanner(r)} } @@ -26,7 +28,7 @@ func (p *Parser) scan() (tok Token, lit string) { p.buf.n = 0 return p.buf.tok, p.buf.lit } - tok, lit = p.s.Scan() + tok, lit = p.s.scan() p.buf.tok, p.buf.lit = tok, lit @@ -45,7 +47,8 @@ func (p *Parser) scanIgnoreWhitespace() (tok Token, lit string) { return } -func (p *Parser) ParseLine() (*Constraint, error) { +// parseLine parses the current line +func (p *Parser) parseLine() (*Constraint, error) { /* in this version, line will be for example s3 = s1 * s4 @@ -111,12 +114,13 @@ func addToArrayIfNotExist(arr []string, elem string) []string { return arr } +// Parse parses the lines and returns the compiled Circuit func (p *Parser) Parse() (*Circuit, error) { circuit := &Circuit{} circuit.Signals = append(circuit.Signals, "one") nInputs := 0 for { - constraint, err := p.ParseLine() + constraint, err := p.parseLine() if err != nil { break } diff --git a/r1csqap/r1csqap.go b/r1csqap/r1csqap.go index 9cf69c5..8e6458b 100644 --- a/r1csqap/r1csqap.go +++ b/r1csqap/r1csqap.go @@ -6,6 +6,7 @@ import ( "github.com/arnaucube/go-snark/fields" ) +// Transpose transposes the *big.Int matrix func Transpose(matrix [][]*big.Int) [][]*big.Int { var r [][]*big.Int for i := 0; i < len(matrix[0]); i++ { @@ -18,6 +19,7 @@ func Transpose(matrix [][]*big.Int) [][]*big.Int { return r } +// ArrayOfBigZeros creates a *big.Int array with n elements to zero func ArrayOfBigZeros(num int) []*big.Int { bigZero := big.NewInt(int64(0)) var r []*big.Int @@ -27,15 +29,19 @@ func ArrayOfBigZeros(num int) []*big.Int { return r } +// PolynomialField is the Polynomial over a Finite Field where the polynomial operations are performed type PolynomialField struct { F fields.Fq } +// NewPolynomialField creates a new PolynomialField with the given FiniteField func NewPolynomialField(f fields.Fq) PolynomialField { return PolynomialField{ f, } } + +// Mul multiplies two polinomials over the Finite Field func (pf PolynomialField) Mul(a, b []*big.Int) []*big.Int { r := ArrayOfBigZeros(len(a) + len(b) - 1) for i := 0; i < len(a); i++ { @@ -47,6 +53,8 @@ func (pf PolynomialField) Mul(a, b []*big.Int) []*big.Int { } return r } + +// Div divides two polinomials over the Finite Field, returning the result and the remainder func (pf PolynomialField) Div(a, b []*big.Int) ([]*big.Int, []*big.Int) { // https://en.wikipedia.org/wiki/Division_algorithm r := ArrayOfBigZeros(len(a) - len(b) + 1) @@ -70,6 +78,7 @@ func max(a, b int) int { return b } +// Add adds two polinomials over the Finite Field func (pf PolynomialField) Add(a, b []*big.Int) []*big.Int { r := ArrayOfBigZeros(max(len(a), len(b))) for i := 0; i < len(a); i++ { @@ -81,6 +90,7 @@ func (pf PolynomialField) Add(a, b []*big.Int) []*big.Int { return r } +// Sub substracts two polinomials over the Finite Field func (pf PolynomialField) Sub(a, b []*big.Int) []*big.Int { r := ArrayOfBigZeros(max(len(a), len(b))) for i := 0; i < len(a); i++ { @@ -92,6 +102,7 @@ func (pf PolynomialField) Sub(a, b []*big.Int) []*big.Int { return r } +// Eval evaluates the polinomial over the Finite Field at the given value x func (pf PolynomialField) Eval(v []*big.Int, x *big.Int) *big.Int { r := big.NewInt(int64(0)) for i := 0; i < len(v); i++ { @@ -102,6 +113,7 @@ func (pf PolynomialField) Eval(v []*big.Int, x *big.Int) *big.Int { return r } +// NewPolZeroAt generates a new polynomial that has value zero at the given value func (pf PolynomialField) NewPolZeroAt(pointPos, totalPoints int, height *big.Int) []*big.Int { fac := 1 for i := 1; i < totalPoints+1; i++ { @@ -122,6 +134,7 @@ func (pf PolynomialField) NewPolZeroAt(pointPos, totalPoints int, height *big.In return r } +// LagrangeInterpolation performs the Lagrange Interpolation / Lagrange Polynomials operation func (pf PolynomialField) LagrangeInterpolation(v []*big.Int) []*big.Int { // https://en.wikipedia.org/wiki/Lagrange_polynomial var r []*big.Int @@ -132,6 +145,7 @@ func (pf PolynomialField) LagrangeInterpolation(v []*big.Int) []*big.Int { return r } +// R1CSToQAP converts the R1CS values to the QAP values func (pf PolynomialField) R1CSToQAP(a, b, c [][]*big.Int) ([][]*big.Int, [][]*big.Int, [][]*big.Int, []*big.Int) { aT := Transpose(a) bT := Transpose(b) @@ -157,6 +171,7 @@ func (pf PolynomialField) R1CSToQAP(a, b, c [][]*big.Int) ([][]*big.Int, [][]*bi return alphas, betas, gammas, z } +// CombinePolynomials combine the given polynomials arrays into one, also returns the P(x) func (pf PolynomialField) CombinePolynomials(r []*big.Int, ap, bp, cp [][]*big.Int) ([]*big.Int, []*big.Int, []*big.Int, []*big.Int) { var alpha []*big.Int for i := 0; i < len(r); i++ { @@ -178,6 +193,7 @@ func (pf PolynomialField) CombinePolynomials(r []*big.Int, ap, bp, cp [][]*big.I return alpha, beta, gamma, px } +// DivisorPolynomial returns the divisor polynomial given two polynomials func (pf PolynomialField) DivisorPolinomial(px, z []*big.Int) []*big.Int { quo, _ := pf.Div(px, z) return quo diff --git a/snark.go b/snark.go index 3187e2a..43691b8 100644 --- a/snark.go +++ b/snark.go @@ -11,6 +11,7 @@ import ( "github.com/arnaucube/go-snark/r1csqap" ) +// Setup is the data structure holding the Trusted Setup data. The Setup.Toxic sub struct must be destroyed after the GenerateTrustedSetup function is completed type Setup struct { Toxic struct { T *big.Int // trusted setup secret @@ -48,6 +49,7 @@ type Setup struct { } } +// Proof contains the parameters to proof the zkSNARK type Proof struct { PiA [3]*big.Int PiAp [3]*big.Int @@ -60,6 +62,7 @@ type Proof struct { PublicSignals []*big.Int } +// GenerateTrustedSetup generates the Trusted Setup from a compiled Circuit. The Setup.Toxic sub data structure must be destroyed func GenerateTrustedSetup(bn bn128.Bn128, fqR fields.Fq, pf r1csqap.PolynomialField, witnessLength int, circuit circuitcompiler.Circuit, alphas, betas, gammas [][]*big.Int, zx []*big.Int) (Setup, error) { var setup Setup var err error @@ -172,6 +175,7 @@ func GenerateTrustedSetup(bn bn128.Bn128, fqR fields.Fq, pf r1csqap.PolynomialFi return setup, nil } +// GenerateProofs generates all the parameters to proof the zkSNARK from the Circuit, Setup and the Witness func GenerateProofs(bn bn128.Bn128, f fields.Fq, circuit circuitcompiler.Circuit, setup Setup, hx []*big.Int, w []*big.Int) (Proof, error) { var proof Proof proof.PiA = [3]*big.Int{bn.G1.F.Zero(), bn.G1.F.Zero(), bn.G1.F.Zero()} @@ -206,6 +210,7 @@ func GenerateProofs(bn bn128.Bn128, f fields.Fq, circuit circuitcompiler.Circuit return proof, nil } +// VerifyProof verifies over the BN128 the Pairings of the Proof func VerifyProof(bn bn128.Bn128, circuit circuitcompiler.Circuit, setup Setup, proof Proof) bool { // e(piA, Va) == e(piA', g2) diff --git a/snark_test.go b/snark_test.go index 8a166c9..e415304 100644 --- a/snark_test.go +++ b/snark_test.go @@ -13,7 +13,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestZkFromHardcodedR1CS(t *testing.T) { +func TestZkFromFlatCircuitCode(t *testing.T) { bn, err := bn128.NewBn128() assert.Nil(t, err) @@ -23,41 +23,39 @@ func TestZkFromHardcodedR1CS(t *testing.T) { // new Polynomial Field pf := r1csqap.NewPolynomialField(fqR) - b0 := big.NewInt(int64(0)) - b1 := big.NewInt(int64(1)) + // compile circuit and get the R1CS + flatCode := ` + func test(x): + aux = x*x + y = aux*x + z = x + y + out = z + 5 + ` + fmt.Print("\nflat code of the circuit:") + fmt.Println(flatCode) + + // parse the code + parser := circuitcompiler.NewParser(strings.NewReader(flatCode)) + circuit, err := parser.Parse() + assert.Nil(t, err) + fmt.Println("\ncircuit data:", circuit) + b3 := big.NewInt(int64(3)) - b5 := big.NewInt(int64(5)) - b9 := big.NewInt(int64(9)) - b27 := big.NewInt(int64(27)) - b30 := big.NewInt(int64(30)) - b35 := big.NewInt(int64(35)) - a := [][]*big.Int{ - []*big.Int{b0, b1, b0, b0, b0, b0}, - []*big.Int{b0, b0, b0, b1, b0, b0}, - []*big.Int{b0, b1, b0, b0, b1, b0}, - []*big.Int{b5, b0, b0, b0, b0, b1}, - } - b := [][]*big.Int{ - []*big.Int{b0, b1, b0, b0, b0, b0}, - []*big.Int{b0, b1, b0, b0, b0, b0}, - []*big.Int{b1, b0, b0, b0, b0, b0}, - []*big.Int{b1, b0, b0, b0, b0, b0}, - } - c := [][]*big.Int{ - []*big.Int{b0, b0, b0, b1, b0, b0}, - []*big.Int{b0, b0, b0, b0, b1, b0}, - []*big.Int{b0, b0, b0, b0, b0, b1}, - []*big.Int{b0, b0, b1, b0, b0, b0}, - } + inputs := []*big.Int{b3} + // wittness + w := circuit.CalculateWitness(inputs) + fmt.Println("\nwitness", w) + + // flat code to R1CS + fmt.Println("\ngenerating R1CS from flat code") + a, b, c := circuit.GenerateR1CS() + fmt.Println("\nR1CS:") + fmt.Println("a:", a) + fmt.Println("b:", b) + fmt.Println("c:", c) + alphas, betas, gammas, zx := pf.R1CSToQAP(a, b, c) - // wittness = 1, 3, 35, 9, 27, 30 - w := []*big.Int{b1, b3, b35, b9, b27, b30} - circuit := circuitcompiler.Circuit{ - NVars: 6, - NPublic: 0, - NSignals: len(w), - } ax, bx, cx, px := pf.CombinePolynomials(w, alphas, betas, gammas) hx := pf.DivisorPolinomial(px, zx) @@ -76,18 +74,18 @@ func TestZkFromHardcodedR1CS(t *testing.T) { assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4)) // calculate trusted setup - setup, err := GenerateTrustedSetup(bn, fqR, pf, len(w), circuit, alphas, betas, gammas, zx) + setup, err := GenerateTrustedSetup(bn, fqR, pf, len(w), *circuit, alphas, betas, gammas, zx) assert.Nil(t, err) - fmt.Println("t", setup.Toxic.T) + fmt.Println("\nt:", setup.Toxic.T) // piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t) - proof, err := GenerateProofs(bn, fqR, circuit, setup, hx, w) + proof, err := GenerateProofs(bn, fqR, *circuit, setup, hx, w) assert.Nil(t, err) - assert.True(t, VerifyProof(bn, circuit, setup, proof)) + assert.True(t, VerifyProof(bn, *circuit, setup, proof)) } -func TestZkFromFlatCircuitCode(t *testing.T) { +func TestZkFromHardcodedR1CS(t *testing.T) { bn, err := bn128.NewBn128() assert.Nil(t, err) @@ -97,34 +95,41 @@ func TestZkFromFlatCircuitCode(t *testing.T) { // new Polynomial Field pf := r1csqap.NewPolynomialField(fqR) - // compile circuit and get the R1CS - flatCode := ` - func test(x): - aux = x*x - y = aux*x - z = x + y - out = z + 5 - ` - // parse the code - parser := circuitcompiler.NewParser(strings.NewReader(flatCode)) - circuit, err := parser.Parse() - assert.Nil(t, err) - fmt.Println(circuit) - // flat code to R1CS - fmt.Println("generating R1CS from flat code") - a, b, c := circuit.GenerateR1CS() - - alphas, betas, gammas, zx := pf.R1CSToQAP(a, b, c) - - // wittness = 1, 3, 35, 9, 27, 30 + b0 := big.NewInt(int64(0)) b1 := big.NewInt(int64(1)) b3 := big.NewInt(int64(3)) + b5 := big.NewInt(int64(5)) b9 := big.NewInt(int64(9)) b27 := big.NewInt(int64(27)) b30 := big.NewInt(int64(30)) b35 := big.NewInt(int64(35)) - w := []*big.Int{b1, b3, b35, b9, b27, b30} + a := [][]*big.Int{ + []*big.Int{b0, b1, b0, b0, b0, b0}, + []*big.Int{b0, b0, b0, b1, b0, b0}, + []*big.Int{b0, b1, b0, b0, b1, b0}, + []*big.Int{b5, b0, b0, b0, b0, b1}, + } + b := [][]*big.Int{ + []*big.Int{b0, b1, b0, b0, b0, b0}, + []*big.Int{b0, b1, b0, b0, b0, b0}, + []*big.Int{b1, b0, b0, b0, b0, b0}, + []*big.Int{b1, b0, b0, b0, b0, b0}, + } + c := [][]*big.Int{ + []*big.Int{b0, b0, b0, b1, b0, b0}, + []*big.Int{b0, b0, b0, b0, b1, b0}, + []*big.Int{b0, b0, b0, b0, b0, b1}, + []*big.Int{b0, b0, b1, b0, b0, b0}, + } + alphas, betas, gammas, zx := pf.R1CSToQAP(a, b, c) + // wittness = 1, 3, 35, 9, 27, 30 + w := []*big.Int{b1, b3, b35, b9, b27, b30} + circuit := circuitcompiler.Circuit{ + NVars: 6, + NPublic: 0, + NSignals: len(w), + } ax, bx, cx, px := pf.CombinePolynomials(w, alphas, betas, gammas) hx := pf.DivisorPolinomial(px, zx) @@ -143,13 +148,13 @@ func TestZkFromFlatCircuitCode(t *testing.T) { assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4)) // calculate trusted setup - setup, err := GenerateTrustedSetup(bn, fqR, pf, len(w), *circuit, alphas, betas, gammas, zx) + setup, err := GenerateTrustedSetup(bn, fqR, pf, len(w), circuit, alphas, betas, gammas, zx) assert.Nil(t, err) fmt.Println("t", setup.Toxic.T) // piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t) - proof, err := GenerateProofs(bn, fqR, *circuit, setup, hx, w) + proof, err := GenerateProofs(bn, fqR, circuit, setup, hx, w) assert.Nil(t, err) - assert.True(t, VerifyProof(bn, *circuit, setup, proof)) + assert.True(t, VerifyProof(bn, circuit, setup, proof)) }